# Classical HII region expansion

## TOC
* [Packages & Settings](#packages)
* [Spectrum binning](#spectrum-binning)
* [Snapshots](#snapshots)
* [Loading Snapshots](#loading)
* [Maps](#maps)
* [Lineouts](#lineouts)
* [I-front position](#pos-vel)

## Packages & Settings  <a class="anchor" id="packages"></a>

In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C
from astropy.cosmology import WMAP7

# !pip install scipy
from scipy.io import FortranFile
from scipy.interpolate import interp1d

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "simple_white"

from math import pi

## Spectrum binning <a class="anchor" id="spectrum-binning"></a>

In [None]:
def blackbody(e_Ryd, T_K):
    """Planck spectrum B_nu evaluated at photon energy ``e_Ryd`` and temperature ``T_K``.

    Both arguments are astropy Quantities (an energy and a temperature). The
    expression 2 E^3 / (h c)^2 / (exp(E / k_B T) - 1) equals 2 h nu^3 / c^2 / (...)
    with nu = E / h. The result is converted to erg cm^-2.
    """
    exponent = e_Ryd / (C.k_B * T_K)
    b_nu = 2 / (C.h * C.c)**2 * e_Ryd**3 / (np.exp(exponent) - 1)
    return b_nu.to(U.erg / U.cm**2)

# Preview of the source spectrum: a T = 1e5 K blackbody sampled from 1 to 10 Ryd.
Ryd = 13.6 * U.eV
T = 1e5 * U.K

energies = np.logspace(0, 1, 100)  # photon energies in units of Ryd
powers = [blackbody(energy * Ryd, T).value for energy in energies]

spectrum = go.Scatter(
    x=energies,
    y=powers,
    mode='lines',
    line=dict(color='black', width=2),
)

fig = go.Figure()
fig.add_trace(spectrum)
fig.update_layout(
    width=500,
    height=500,
    xaxis=dict(mirror=True, type='linear', title='E [Ryd]'),
    yaxis=dict(mirror=True, type='linear', title='P [erg cm<sup>-2</sup>]'),
)
fig.show()

## Snapshots  <a class="anchor" id="snapshots"></a>

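Run
```bash
wget -r --level=0 -E --ignore-length -x -k -p -erobots=off -np -N https://astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T5_results/
```
to download the Iliev-5 test snapshots. Because of `-x`, wget recreates the remote directory tree locally, so the files end up under `astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T5_results`, which is where `ILIEV_5_DIR` below points.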

In [None]:
ILIEV_5_DIR = 'astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T5_results'

IMAGE_DIR = './images'
Path(IMAGE_DIR).mkdir(parents=True, exist_ok=True)
TIMES = [10, 30, 100, 200, 500]


def _iliev5_snapshots(stem):
    """Map each time in TIMES to ``{'path': ILIEV_5_DIR/<stem><k>.bin}`` with k = 1, 2, ..."""
    return {t: {'path': f"{ILIEV_5_DIR}/{stem}{k}.bin"} for k, t in enumerate(TIMES, start=1)}


def _iliev5_code(row, col, stem):
    """Describe one comparison code: its subplot position and its two binary series.

    ``'orig'`` files are ``<stem><k>.bin`` and ``'add'`` files are ``<stem>_add<k>.bin``.
    """
    return {
        'row': row, 'col': col,
        'orig': _iliev5_snapshots(stem),
        'add': _iliev5_snapshots(f"{stem}_add"),
    }


BINARIES = {
    'CAPREOLE+C2-RAY': _iliev5_code(1, 3, 'c2ray'),
    'Enzo': _iliev5_code(1, 4, 'enzo'),
    # 'Flash': _iliev5_code(2, 3, 'flash'),
    # 'HART': _iliev5_code(2, 4, 'Gnedin'),
    # 'LICORICE': _iliev5_code(3, 3, 'Licorice'),
    'RSPH': _iliev5_code(3, 4, 'susa'),
    'RH1D': _iliev5_code(4, 3, 'RH1D'),
    # 'ZEUS-MP': _iliev5_code(4, 4, 'Zeus'),
}

def _rhyme_run(run_name, map_plot):
    """Describe one Rhyme run stored as ``./<run_name>/<run_name>-000000.chombo.h5``.

    ``map_plot`` sets the ``'mapPlot'`` flag that controls whether the run is
    drawn in the 2-D map figure.
    """
    return {
        'mapPlot': map_plot,
        0: {'path': f'./{run_name}/{run_name}-000000.chombo.h5'},
    }


# Rhyme runs with 3/10/30/100 bins, linearly or logarithmically spaced (per the run names).
RHYME = {
    'Rhyme 3': _rhyme_run('Iliev-5-CaseA-3bins', map_plot=False),
    'Rhyme 10': _rhyme_run('Iliev-5-CaseA-10bins', map_plot=False),
    'Rhyme 30': _rhyme_run('Iliev-5-CaseA-30bins', map_plot=False),
    'Rhyme 100': _rhyme_run('Iliev-5-CaseA-100bins', map_plot=False),
    'Rhyme Log 3': _rhyme_run('Iliev-5-CaseA-logBins-3bins', map_plot=True),
    'Rhyme Log 10': _rhyme_run('Iliev-5-CaseA-logBins-10bins', map_plot=False),
    'Rhyme Log 30': _rhyme_run('Iliev-5-CaseA-logBins-30bins', map_plot=False),
    'Rhyme Log 100': _rhyme_run('Iliev-5-CaseA-logBins-100bins', map_plot=True),
}

# Disabled single-run configurations:
# Rhyme_CaseA = {
#     0: {'path': f'./Iliev-5-CaseA-1000/Iliev-5-CaseA-000000.chombo.h5', },
# }

# Rhyme_CaseB = {
#     0: {'path': f'./Iliev-5-CaseB-1000/Iliev-5-CaseB-000000.chombo.h5', },
# }

# Check if files exist
def check_if_exist(path):
    """Print ``Not found: <path>`` when ``path`` is not an existing regular file.

    Parameters
    ----------
    path : str or os.PathLike
        File to check.

    Returns
    -------
    bool
        True if the file exists, False otherwise.
    """
    # Use the argument itself. Reading a global loop variable here would
    # silently check the wrong file.
    exists = Path(path).is_file()
    if not exists:
        print(f"Not found: {path}")
    return exists
        
# Report every configured snapshot file that is missing on disk:
# first both binary series of each comparison code, then the Rhyme runs.
for sim in BINARIES.values():
    for kind in ('orig', 'add'):
        for snap in sim[kind].values():
            check_if_exist(snap['path'])

for sim in RHYME.values():
    check_if_exist(sim[0]['path'])
# for snap in Rhyme_CaseA.values():
#     check_if_exist(snap['path'])
# for snap in Rhyme_CaseB.values():
#     check_if_exist(snap['path'])

## Loading snapshots  <a class="anchor" id="loading"></a>

In [None]:
import os
def time_to_snap_id(t, ds):
    """Return the index of the snapshot in ``ds.dataset`` whose time is closest to ``t``.

    Snapshots are visited in order until the first one with time > ``t``. That
    snapshot and its predecessor are then compared, and the closer one wins
    (a tie goes to the earlier snapshot). Index 0 is returned when even the
    first snapshot is later than ``t``. The last index is returned when no
    snapshot is later than ``t``.

    Side effect: ``ds.dataset`` is left on whichever snapshot was jumped to last.
    """
    for idx in range(ds.dataset.num_of_snapshots):
        ds.dataset.jump_to(idx)
        later = ds.dataset.time
        if not later > t:
            continue

        if idx == 0:
            return idx

        ds.dataset.jump_to(idx - 1)
        earlier = ds.dataset.time
        later_is_closer = abs(t - later) < abs(t - earlier)
        return idx if later_is_closer else idx - 1

    return ds.dataset.num_of_snapshots - 1


def load_rhyme(snap_id, rhyme, Gamma=5./3.):
    """Load one Rhyme snapshot and convert it to the comparison fields.

    Parameters
    ----------
    snap_id : int
        Snapshot index passed to ``rhyme.dataset.jump_to``.
    rhyme : PyRhyme
        Opened Rhyme dataset.
    Gamma : float
        Adiabatic index. It is used both for the pressure computed by
        ``rhyme.calc_temperature`` and for the sound speed in the Mach number.

    Returns
    -------
    dict
        Arrays reshaped to ``rhyme.dataset.problem_domain``:

        - ``'n'``: density converted to cm^-3.
        - ``'nHI'``: the ``ntr_frac_0`` field.
        - ``'nHII'``: ``1 - nHI``.
        - ``'T'``: the stored ``temp`` field.
        - ``'p'``: pressure in g cm^-1 s^-2.
        - ``'M'``: Mach number.
    """
    rhyme.dataset.jump_to(snap_id)

    v = rhyme.load_variables(silent=True)
    domain = rhyme.dataset.problem_domain

    m_p = C.m_p.to(U.g).value
    # NOTE(review): assumes momentum / density is in Mpc/Myr code units, and that
    # the density-unit factor applied to rho below is 1 (rho_u is not rescaled) — confirm.
    vel_to_cgs = (1 * U.Mpc / U.Myr).to(U.cm / U.s).value

    rho = v['rho'][0].reshape(domain) * (1 * v['rho'][1]).to(U.cm**-3).value
    vx = v['rho_u'][0].reshape(domain) / rho * vel_to_cgs
    vy = v['rho_v'][0].reshape(domain) / rho * vel_to_cgs
    vz = v['rho_w'][0].reshape(domain) / rho * vel_to_cgs
    T_orig = v['temp'][0].reshape(domain)
    fHI = v['ntr_frac_0'][0].reshape(domain)

    # Pass the caller's Gamma so pressure and Mach number use the same adiabatic index.
    # The recomputed temperature is discarded; the stored 'temp' field is reported instead.
    p, _, _ = rhyme.calc_temperature(v, X=1.0, Y=0.0, Gamma=Gamma)
    p = p.reshape(domain)
    p *= m_p * (1 * rhyme.dataset.active['h5']['attrs']['pressure_unit']).to(U.cm**-3 * U.cm**2 / U.s**2).value

    sound_speed = np.sqrt(Gamma * p / (m_p * rho))
    M = np.sqrt(vx**2 + vy**2 + vz**2) / sound_speed

    return {
        'n': rho,
        'nHI': fHI,
        'nHII': 1.0 - fHI,
        'T': T_orig,
        'p': p,
        'M': M,
    }


def _read_iliev5_binary(path, fields):
    """Read one Fortran-unformatted Iliev-5 file into ``{field: array}``.

    The first record is the grid shape (int32). Each following record is a
    float32 array reshaped to that shape and stored under ``fields`` in order.
    The file is closed even if a read fails.
    """
    with FortranFile(path, 'r') as f:
        domain = f.read_ints(np.int32)
        return {name: f.read_reals(np.float32).reshape(domain) for name in fields}


def reading_binary_files():
    """Load every configured snapshot into ``{time: {sim_name: {field: array}}}``.

    Binary codes (BINARIES) contribute:

    - from the ``'orig'`` files: ``'nHI'``, ``'p'``, ``'T'``;
    - from the ``'add'`` files: ``'n'``, ``'M'``, ``'nHII'``.

    Each Rhyme run (RHYME) is opened once. For every time in TIMES, the
    snapshot closest to that time is loaded with ``load_rhyme``. The handle is
    closed and cleaned afterwards, even on error.
    """
    result = {}
    record_layout = (('orig', ('nHI', 'p', 'T')), ('add', ('n', 'M', 'nHII')))

    for simname, sim in BINARIES.items():
        for kind, fields in record_layout:
            for snap_time, snap in sim[kind].items():
                entry = result.setdefault(snap_time, {}).setdefault(simname, {})
                entry.update(_read_iliev5_binary(snap['path'], fields))

    for sn, sim in RHYME.items():
        # One handle per run; previously a handle was opened per time and only
        # the last one was released.
        r = PyRhyme(sim[0]['path'])
        try:
            for t in TIMES:
                result.setdefault(t, {})[sn] = load_rhyme(time_to_snap_id(t, r), r)
        finally:
            r.dataset.close_current()
            r.dataset.clean_all()
        del r

    return result


if __name__ == '__main__' and '__file__' not in globals():
    DATA = reading_binary_files()

In [None]:
def check_cell(attr='n', x=1, y=1, z=1):
    """Print ``attr`` at cell ``[x, y, z]`` for every time and simulation in DATA.

    A blank line precedes the block of lines for each time.
    """
    for t in DATA:
        snapshot = DATA[t]
        print('')
        for simname in snapshot:
            value = snapshot[simname][attr][x, y, z]
            print(f"{t:3.1f}: {simname:18s}: {attr} = {value:3.2e}")

if __name__ == '__main__' and '__file__' not in globals():
    check_cell(attr='nHI', x=64)

## Maps <a class="anchor" id="maps"></a>

In [None]:
def plot_maps(x, y, heatmaps):
    """Build the 4x4 grid of wedge ("pie") maps comparing the simulations.

    Each panel shows one simulation. The quarter plane of a panel is split into
    ``len(attrs)`` equal angular wedges, and each wedge shows one quantity
    (pressure, temperature, neutral fraction). The wedges are separated by
    white lines and labelled in white. Two simulations use the large 2x2
    panels in column 1; the rest use the single panels in columns 3-4.

    Parameters
    ----------
    x, y : array-like
        Coordinates passed to every ``go.Heatmap`` (tick labels assume 0-15).
    heatmaps : dict
        ``{sim_name: {'row': int, 'col': int, 'p': 2-D array, 'T': ..., 'fHI': ...}}``
        holding log10 values. ``sim_name`` must be a key of ``titles`` below.
        The input arrays are copied, never modified.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure (not shown).
    """
    titles = {
        'Rhyme Log 100': {'name': 'Rhyme (100 Log bins)', 'num': 0},
        'CAPREOLE+C2-RAY': {'name': 'C2Ray', 'num': 1},
        'Enzo': {'name': 'Enzo', 'num': 2},
        'Flash': {'name': 'Flash', 'num': 3},
        'HART': {'name': 'HART', 'num': 4},
        'Rhyme Log 3': {'name': 'Rhyme (3 Log bins)', 'num': 5},
        'LICORICE': {'name': 'Licorice', 'num': 6},
        'RSPH': {'name': 'RSPH', 'num': 7},
        'RH1D': {'name': 'RH1D', 'num': 8},
        'ZEUS-MP': {'name': "Zeus-MP", 'num': 9}
    }

    fig = make_subplots(
        4, 4,
        column_widths=[1] * 4, row_heights=[1] * 4,
        specs=[
            [{"type": "heatmap", "rowspan": 2, "colspan": 2}, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [None, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [{"type": "heatmap", "rowspan": 2, "colspan": 2}, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [None, None, {"type": "heatmap"}, {"type": "heatmap"}],
        ],
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=True,
        x_title='x [px]', y_title='y [px]',
    )

    fig.layout.annotations[0]["font"] = dict(size=18, color='black')
    fig.layout.annotations[1]["font"] = dict(size=18, color='black')

    attrs = ['p', 'T', 'fHI']
    attr_titles = {'nH': 'log(n<sub>H</sub>)', 'p': 'log(p)', 'T': 'log(T)', 'fHI': 'log(f<sub>HI</sub>)'}
    attr_cs = {'nH': 'Spectral', 'p': 'Viridis', 'T': 'Inferno', 'fHI': 'Spectral'}
    attr_cs_rng = {'nH': [-4, -2], 'p': [-17, -14.5], 'T': [2, 4.5], 'fHI': [-6, 0]}
    n_slices = len(attrs)

    # Colour ranges are fixed per attribute, so the colour bars of any single
    # simulation are valid for all panels: show them for exactly one.
    colorbar_owner = next(iter(heatmaps), None)

    for sn, h in heatmaps.items():
        si = titles[sn]['num']
        xref = 'x' if si == 0 else f"x{si+1}"
        yref = 'y' if si == 0 else f"y{si+1}"

        for slc, attr in enumerate(attrs):
            d1, d2 = slc * pi / 2 / n_slices, (slc + 1) * pi / 2 / n_slices

            # Blank (NaN) every cell whose angle arctan2(j + 0.5, i + 0.5)
            # falls outside this attribute's wedge [d1, d2]. Work on a float
            # copy so the caller's arrays stay untouched.
            xy = np.array(h[attr], dtype=float)
            ii, jj = np.indices(xy.shape)
            d = np.arctan2(jj + 0.5, ii + 0.5)
            xy[(d < d1) | (d > d2)] = np.nan

            fig.add_trace(go.Heatmap(
                x=x, y=y, z=xy, name=f"{sn}: {attr}",
                zauto=False, zmin=attr_cs_rng[attr][0], zmax=attr_cs_rng[attr][1],
                colorbar=dict(
                    title=dict(text=attr_titles[attr], side='bottom', font=dict(size=18)),
                    thicknessmode='pixels', thickness=10, tickfont=dict(size=18),
                    x=1.02, y=slc/n_slices, len=1/n_slices, xanchor='left', yanchor='bottom',
                ),
                colorscale=attr_cs[attr], showscale=(sn == colorbar_owner),
            ), h['row'], h['col'])

            if h['row'] == 4 or sn == 'Rhyme Log 3':
                fig.update_xaxes(
                    row=h['row'], col=h['col'], mirror=True, title_font=dict(size=18), tickfont=dict(size=14),
                    tickmode='array', tickvals=[0, 5, 10, 15], ticktext=['0', '5', '10', '15'],
                )
            else:
                fig.update_xaxes(row=h['row'], col=h['col'], mirror=True, showticklabels=False,)

            if h['col'] == 1:
                fig.update_yaxes(
                    row=h['row'], col=h['col'], mirror=True, title_font=dict(size=18), tickfont=dict(size=14),
                    tickmode='array', tickvals=[0, 5, 10, 15], ticktext=['0', '5', '10', '15'],
                )
            else:
                fig.update_yaxes(row=h['row'], col=h['col'], mirror=True, showticklabels=False,)

            # Wedge separator end point and label position. The radius reaches
            # slightly beyond the 0-15 axis range.
            l = 15 + 8.0/128
            if d2 <= pi / 4:
                xend = l
                yend = l * np.tan(d2)
                xannot = .9 * l * np.sin((d1 + d2)/2)
                yannot = .9 * l * np.cos((d1 + d2)/2)
            else:
                xend = l * np.tan(pi/2 - d2)
                yend = l
                xannot = .9 * l * np.cos(pi/2 - (d1+d2)/2)
                yannot = .9 * l * np.sin(pi/2 - (d1+d2)/2)

            # The last wedge ends on the panel edge, so it needs no separator line.
            if d2 < pi / 2:
                fig.add_shape(
                    type='line', xref=xref, yref=yref,
                    x0=0, y0=0, x1=xend, y1=yend,
                    line=dict(color="white", width=4 if h['col'] == 1 else 2.5), opacity=1.0,
                )

            textangle = 90 / (2 * n_slices) * (2 * slc + 1)
            fig.add_annotation(
                x=xannot, y=yannot, xref=xref, yref=yref,
                text=attr_titles[attr], showarrow=False,
                font=dict(color='white', size=24 if h['col'] == 1 else 16), textangle=textangle,
            )

        fig.add_annotation(
            x=14, y=0.25, xanchor='right', yanchor='bottom',
            xref=xref, yref=yref,
            text=titles[sn]['name'], showarrow=False,
            font=dict(color='white', size=22 if h['col'] == 1 else 14),
        )

    width, height = 1100, 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(b=60, t=10, l=60, r=60),
    )

    return fig

In [None]:
def fHI_maps(t=200):
    """Build and show the wedge maps for every simulation at time ``t`` [Myr].

    Binary codes are placed at their ``BINARIES`` row/column. Rhyme runs are
    only included when ``RHYME[name]['mapPlot']`` is set. 'Rhyme Log 100'
    goes to the top-left panel and 'Rhyme Log 3' to the lower-left panel.
    Each map is the slice at index 2 along the third axis of the 3-D fields.
    """
    if t not in DATA:
        print(f"Don't have a snapshot at t = {t} Myr")
        return

    fixed_positions = {'Rhyme Log 100': (1, 1), 'Rhyme Log 3': (3, 1)}
    z = 2  # index along the third axis of the mapped slice

    heatmaps = {}
    for sn, sim in DATA[t].items():
        if sn.startswith('Rhyme') and not RHYME[sn]['mapPlot']:
            continue

        if sn in fixed_positions:
            row, col = fixed_positions[sn]
        else:
            row, col = BINARIES[sn]['row'], BINARIES[sn]['col']

        nHI = sim['nHI'][:, :, z]
        nHII = sim['nHII'][:, :, z]
        heatmaps[sn] = {
            'row': row,
            'col': col,
            'nH': np.log10(nHI + nHII),
            'p': np.log10(sim['p'][:, :, z]),
            'T': np.log10(sim['T'][:, :, z]),
            'fHI': np.log10(nHI / (nHI + nHII)),
        }

    grid = np.linspace(0, 15, 128)
    fig = plot_maps(grid, grid, heatmaps)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    fHI_maps(t=100)

## Lineouts <a class="anchor" id="lineouts"></a>

In [None]:
def plot_lineouts(ts, attr, title):
    """Plot lineouts of ``attr`` for every simulation, one panel per time in ``ts``.

    Each lineout runs along the first grid axis at index 1 of the other two
    axes, against x / L_box in [0, 1]. ``attr == 'fHI'`` plots
    nHI / (nHI + nHII); any other value plots ``sim[attr]`` directly. The
    y values are log10 and are labelled as powers of ten.

    Parameters
    ----------
    ts : sequence
        Snapshot times (keys of DATA), one column each.
    attr : str
        Field name, or ``'fHI'``.
    title : str
        Shared y-axis title.
    """
    colors = px.colors.qualitative.Dark24  # (Pastel is an alternative palette)

    len_t = len(ts)

    fig = make_subplots(
        1, len_t,
        column_widths=[1] * len_t, row_heights=[1],
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=False, shared_yaxes=True,
        x_title='x / L<sub>box</sub>', y_title=title,
        subplot_titles=[f"t = {t} Myr" for t in ts],
    )

    # Subplot titles and the shared x/y titles all get the same font.
    for annotation in fig.layout.annotations:
        annotation.font = dict(size=18, color='black')

    tickvals = list(range(-40, 20, 2))
    ticktext = [f"10<sup>{v}</sup>" for v in tickvals]
    l = 1  # fixed index along the 2nd and 3rd axes

    for i, t in enumerate(ts):
        for si, (sn, sim) in enumerate(DATA[t].items()):
            x = np.linspace(0, 1, len(sim['nHI'][:, 0, 0]))
            if attr == 'fHI':
                y = np.log10(sim['nHI'][:, l, l] / (sim['nHI'][:, l, l] + sim['nHII'][:, l, l]))
            else:
                y = np.log10(sim[attr][:, l, l])

            fig.add_trace(go.Scatter(
                x=x, y=y, name=f"{sn}", mode='lines',
                # Wrap around so more simulations than palette colours do not raise IndexError.
                line=dict(color=colors[si % len(colors)], width=1.5),
                showlegend=(i == 0),
            ), 1, i+1)

        # Axis styling is identical for every trace, so apply it once per panel.
        fig.update_xaxes(
            row=1, col=i+1, type='linear', mirror=True, range=(2/128, 1),
            title_font=dict(size=18), tickfont=dict(size=14),
            tickmode='array', tickvals=[2/128, 0.5, 1], ticktext=['0', '.5', '1'],
        )

        fig.update_yaxes(
            row=1, col=i+1, mirror=True,
            title_font=dict(size=18), tickfont=dict(size=14),
            tickmode='array', tickvals=tickvals, ticktext=ticktext,
            showticklabels=(i == 0),
        )

    width = 500 if len_t <= 2 else 1000
    height = width / len_t + 50

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=60, r=60, b=60, l=80),
    )

    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Pressure is in g cm^-1 s^-2 (see the unit conversion in load_rhyme).
    plot_lineouts([10, 100, 200, 500], 'p', 'p [g cm<sup>-1</sup> s<sup>-2</sup>]')
    plot_lineouts([10, 100, 200, 500], 'T', 'T [K]')
    plot_lineouts([10, 100, 200, 500], 'fHI', 'f<sub>HI</sub>')

## I-front position <a class="anchor" id="pos-vel"></a>

In [None]:
def i_front_get_position(xs, fHI, f0=.5):
    """Return every position where the profile ``fHI`` crosses the level ``f0``.

    The profile is cut into runs wherever ``sign(diff(fHI))`` changes. Each run
    with more than two samples is inverted by linear interpolation,
    x = x(fHI), and evaluated at ``f0``. Runs that do not bracket ``f0`` are
    dropped.

    Returns
    -------
    list of (float, number)
        ``(position, direction)`` pairs, where direction is the sign of the
        slope on that run (+1 rising, -1 falling, 0 flat).
    """
    slope_sign = np.sign(np.diff(fHI))
    breaks = np.where(np.diff(slope_sign) != 0)[0] + 1

    # NOTE(review): the sample arrays are split at the same indices as the
    # slope array, so the interval between the last sample of one run and the
    # first sample of the next is never examined — confirm this is intended.
    runs = zip(np.split(xs, breaks), np.split(fHI, breaks), np.split(slope_sign, breaks))

    inverses = []
    for run_x, run_f, run_sign in runs:
        if len(run_x) > 2:
            inverses.append((interp1d(run_f, run_x, bounds_error=False, fill_value=-1), run_sign[0]))

    crossings = []
    for inverse, direction in inverses:
        hit = inverse(f0)
        if hit != -1:  # -1 is the out-of-range fill value
            crossings.append((float(hit), direction))
    return crossings


def i_front_pos():
    """Locate the I-front along the main grid diagonal for every snapshot in DATA.

    For each simulation, the neutral fraction nHI / (nHI + nHII) is sampled on
    the diagonal cells ``[k, k, k]``. The position where it crosses 0.5 is
    stored in ``sim['I-front']['pos_kpc']``. When no unique crossing is found,
    an error line naming the time and simulation is printed and the position
    is set to NaN, so later plots show a gap instead of raising KeyError.
    """
    for t, snap in DATA.items():
        for sn, sim in snap.items():
            front = sim.setdefault('I-front', {})

            nHI_prof = np.einsum('iii->i', sim['nHI'])
            nHII_prof = np.einsum('iii->i', sim['nHII'])
            fHI_prof = nHI_prof / (nHI_prof + nHII_prof)

            # Distance along the diagonal: sqrt(3) times the 0-15 axis
            # coordinate, with one sample per diagonal cell.
            xs = np.sqrt(3) * np.linspace(0, 15, len(fHI_prof))

            pos = i_front_get_position(xs, fHI_prof)

            if len(pos) == 1:
                front['pos_kpc'] = pos[0][0]
            else:
                front['pos_kpc'] = np.nan
                print(f"[Err] Could not extract the I-front position! (t = {t}, {sn}: {len(pos)} crossings)")

if __name__ == '__main__' and '__file__' not in globals():
    i_front_pos()

In [None]:
def i_front_pos_plot():
    """Plot the I-front position against time (log axis) for every simulation.

    Reads ``DATA[t][sim]['I-front']['pos_kpc']`` as filled by ``i_front_pos``.
    Missing entries are plotted as gaps (NaN) instead of raising KeyError.
    Simulation names are taken from the first time present in DATA.
    """
    colors = px.colors.qualitative.Dark24  # (Pastel is an alternative palette)

    fig = go.Figure()

    times = list(DATA.keys())
    sim_names = list(DATA[times[0]].keys()) if times else []

    for i, sn in enumerate(sim_names):
        ys = [DATA[t].get(sn, {}).get('I-front', {}).get('pos_kpc', np.nan) for t in times]

        fig.add_trace(go.Scatter(
            x=times, y=ys, name=sn,
            mode='lines+markers',
            # Wrap around so more simulations than palette colours do not raise IndexError.
            line=dict(color=colors[i % len(colors)]),
        ))

#     def approx_rI(t, rS=5.4 * U.kpc, T=1e4 * U.K, gamma=5./3):
#         cs = np.sqrt(gamma * C.k_B / C.m_p * T)
#         return (rS * (1 + (7 * cs * (t * U.Myr)) / (4 * rS))).to(U.kpc)

#     rI = [approx_rI(t).value for t in DATA.keys()]

#     fig.add_trace(go.Scatter(
#         x=list(DATA.keys()), y=rI, name='Approx. r<sub>I</sub>',
#         mode='lines', line=dict(color='black', width=1.5)
#     ))

    width, height = 500 + 150, 500

    fig.update_layout(
        margin=dict(b=10, t=10, l=10, r=10),
        width=width, height=height,
        xaxis=dict(
            type='log', mirror=True,
            title=dict(text='<i>t</i> [Myr]', font=dict(size=18)),
            tickmode='array',
            tickvals=TIMES,
            ticktext=[str(tt) for tt in TIMES],
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            type='linear', mirror=True,
            title=dict(text='<i>x</i><sub>I-front</sub> [kpc]', font=dict(size=18)),
            tickmode='array',
            tickvals=list(range(2, 11, 2)),
            ticktext=[str(yy) for yy in range(2, 11, 2)],
            tickfont=dict(size=14),
        ),
    )

    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    i_front_pos_plot()