In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C
from astropy.cosmology import WMAP7

# !pip install scipy
from scipy.io import FortranFile

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "simple_white"

In [None]:
# Snapshot files to compare, keyed by case name. Each inner dict holds the
# file path; `compare` later stores loaded fields in it under the time key.
SNAPSHOTS = dict(
    CaseA=dict(path='./Iliev-6-CaseA-PowerLaw/Iliev-6-CaseA-000000.chombo.h5'),
    CaseB=dict(path='Iliev-6-CaseB-PowerLaw/Iliev-6-CaseB-000000.chombo.h5'),
)

# Report (without aborting) any snapshot file that is missing on disk.
for snapshot in SNAPSHOTS.values():
    snapshot_path = snapshot['path']
    if Path(snapshot_path).is_file():
        continue
    print(f"File not found: {snapshot_path}")

In [None]:
def time_to_snap_id(t, ds):
    """Return the index of the snapshot in ``ds`` whose time is closest to ``t``.

    Snapshots are scanned in index order. The scan stops at the first one
    whose time is strictly greater than ``t``. That snapshot is then compared
    with its predecessor, and ties go to the earlier snapshot. If no snapshot
    is later than ``t``, the last index is returned. If ``t`` precedes the
    first snapshot, ``0`` is returned.

    NOTE(review): this assumes snapshot times are non-decreasing with index.
    Confirm this for the datasets in use.

    Parameters
    ----------
    t : number
        Target time, in the same units as ``ds.dataset.time``.
    ds : PyRhyme
        Object exposing ``ds.dataset.num_of_snapshots``, ``ds.dataset.jump_to(i)``
        and ``ds.dataset.time``.

    Returns
    -------
    int
        Snapshot index closest to ``t``.

    Raises
    ------
    ValueError
        If the dataset contains no snapshots.

    Side effects
    ------------
    Leaves ``ds.dataset`` positioned at the last snapshot inspected, which is
    not necessarily the returned index. Callers must ``jump_to`` the result.
    """
    n_snaps = ds.dataset.num_of_snapshots
    if n_snaps < 1:
        # Previously returned -1, which callers then passed to jump_to().
        raise ValueError('dataset contains no snapshots')

    # Time of snapshot i-1, remembered so it need not be re-read via another jump.
    prev_time = None
    for i in range(n_snaps):
        ds.dataset.jump_to(i)
        snap_time = ds.dataset.time
        if snap_time > t:
            if prev_time is None:
                return i
            return i if abs(t - snap_time) < abs(t - prev_time) else i - 1
        prev_time = snap_time

    return n_snaps - 1


def _load_case_fields(path, t):
    """Load the snapshot closest to time ``t`` from ``path`` and convert its fields.

    Parameters
    ----------
    path : str
        Snapshot file readable by ``PyRhyme``.
    t : number
        Target time, in the same units as ``ds.dataset.time``.

    Returns
    -------
    dict
        Arrays reshaped to the dataset's ``problem_domain``, under these keys:

        - ``'rho'``: density converted to cm^-3.
        - ``'vx'``/``'vy'``/``'vz'``: velocities in cm/s.
        - ``'T'``: the stored ``'temp'`` variable, unconverted.
        - ``'fHI'``: the stored ``'ntr_frac_0'`` variable.
        - ``'p'``: pressure in g cm^-1 s^-2.
    """
    ds = PyRhyme(path)
    ds.dataset.jump_to(time_to_snap_id(t, ds))
    domain = ds.dataset.problem_domain

    v = ds.load_variables(silent=True)

    rho_code = v['rho'][0].reshape(domain)
    rho = rho_code * (1 * v['rho'][1]).to(U.cm**-3).value

    # Velocity = momentum density / density, with both in code units. Dividing
    # by the already-converted `rho` (as done previously) mixed unit systems
    # whenever the code density unit is not exactly cm^-3.
    # NOTE(review): assumes the code velocity unit is Mpc/Myr. Confirm against
    # the snapshot's unit attributes.
    vel_to_cgs = (1 * U.Mpc / U.Myr).to(U.cm / U.s).value
    vx = v['rho_u'][0].reshape(domain) / rho_code * vel_to_cgs
    vy = v['rho_v'][0].reshape(domain) / rho_code * vel_to_cgs
    vz = v['rho_w'][0].reshape(domain) / rho_code * vel_to_cgs

    temp = v['temp'][0].reshape(domain)
    fHI = v['ntr_frac_0'][0].reshape(domain)

    # Only the pressure is kept. The temperature derived here is discarded in
    # favour of the stored 'temp' variable above.
    p, _, _ = ds.calc_temperature(v, X=1.0, Y=0.0, Gamma=5./3.)
    pressure_unit = ds.dataset.active['h5']['attrs']['pressure_unit']
    # m_p [g] * (pressure unit expressed in cm^-3 (cm/s)^2) -> g cm^-1 s^-2.
    p = p.reshape(domain) * (
        C.m_p.to(U.g).value
        * (1 * pressure_unit).to(U.cm**-3 * U.cm**2 / U.s**2).value
    )

    return {'rho': rho, 'vx': vx, 'vy': vy, 'vz': vz, 'T': temp, 'fHI': fHI, 'p': p}


def _diagonal_profile(field):
    """Return the main diagonal ``field[i, i, 0]`` of the first z-slice of a 3-D array.

    For a non-square slice, the diagonal stops at the shorter side.
    """
    return np.diagonal(field[:, :, 0])


def compare(t=1, attr='fHI', attr_label='fHI', log_y=True, width=500, height=500):
    """Plot ``attr`` along the diagonal of the first z-slice for every case in ``SNAPSHOTS``.

    Parameters
    ----------
    t : number
        Target time. Each case uses its snapshot closest to ``t``.
    attr : str
        One of ``'rho'``, ``'vx'``, ``'vy'``, ``'vz'``, ``'T'``, ``'fHI'``, ``'p'``.
        An unknown name prints a message and returns without plotting.
    attr_label : str
        Y-axis title.
    log_y : bool
        Use a logarithmic y axis (the previous, fixed behaviour). Pass ``False``
        for signed quantities such as velocities: non-positive values cannot
        be drawn on a log axis.
    width, height : int
        Figure size in pixels.

    Side effects
    ------------
    Loaded fields are cached in ``SNAPSHOTS[case_name][t]``, so later calls
    with the same ``t`` skip the file I/O. The figure is shown, and the
    function returns ``None``.
    """
    if attr not in ('rho', 'vx', 'vy', 'vz', 'T', 'fHI', 'p'):
        print(f'Unknown attribute: {attr}')
        return

    fig = go.Figure()

    for case_name, case in SNAPSHOTS.items():
        if t not in case:
            # Cache only after every field has loaded. Previously an empty dict
            # was stored first, so a failed load left a poisoned entry that made
            # every later call with this `t` raise KeyError.
            case[t] = _load_case_fields(case['path'], t)

        fig.add_trace(go.Scatter(
            y=_diagonal_profile(case[t][attr]),
            name=case_name,
            mode='lines', line=dict(width=2.5),
        ))

    fig.update_layout(
        width=width, height=height,
        margin=dict(l=10, r=10, b=10, t=10),
        xaxis=dict(mirror=True, type='linear', title='x [px]', exponentformat='power'),
        yaxis=dict(
            mirror=True,
            type='log' if log_y else 'linear',
            title=attr_label,
            exponentformat='power',
        ),
        legend=dict(xanchor='left', yanchor='top', x=0.04, y=0.99),
        showlegend=True,
    )

    fig.show()
    

# Demo call. It runs when this code executes as a notebook cell (where
# `__file__` is undefined), and is skipped when run as a script or imported.
if '__file__' not in globals() and __name__ == '__main__':
    compare(attr_label='fHI', attr='fHI', t=100)