In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "simple_white"

from math import pi

import csv

import os

# Every figure produced below is written into this directory; create it (and parents) up front.
FIGURE_DIR = './figures'
os.makedirs(FIGURE_DIR, exist_ok=True)

def time_to_snap_id(t, ds):
    """Return the index of the snapshot whose time is closest to ``t``.

    Snapshots are visited in index order via ``ds.dataset.jump_to``. The scan
    stops at the first snapshot whose time is strictly greater than ``t`` and
    returns whichever of it and its predecessor is nearer (ties go to the
    predecessor). A ``t`` before the first snapshot maps to 0; a ``t`` past the
    last snapshot maps to ``num_of_snapshots - 1``. Assumes snapshot times
    increase with index -- TODO confirm for all Rhyme outputs.

    Args:
        t: target simulation time, in the same units as ``ds.dataset.time``.
        ds: a PyRhyme instance (anything whose ``dataset`` exposes
            ``num_of_snapshots``, ``jump_to(i)`` and ``time``).

    Returns:
        int: the selected snapshot index. The dataset is left positioned on
        whichever snapshot was read last, so callers must ``jump_to`` the
        returned index themselves.
    """
    dataset = ds.dataset
    prev_time = None  # time of snapshot i-1; remembered so we never jump back to it

    for i in range(dataset.num_of_snapshots):
        dataset.jump_to(i)
        snap_time = dataset.time

        if snap_time > t:
            if prev_time is None:
                # t precedes the very first snapshot
                return i
            return i if abs(t - snap_time) < abs(t - prev_time) else i - 1

        prev_time = snap_time

    # t is at or beyond the last snapshot time
    return dataset.num_of_snapshots - 1


def load_rhyme(snap_id, rhyme, n_dim, gamma=7./5.):
    """Load one snapshot of a Rhyme run and derive primitive fields in SI units.

    Args:
        snap_id: snapshot index to load (passed to ``rhyme.dataset.jump_to``).
        rhyme: a PyRhyme instance.
        n_dim: number of spatial dimensions to read velocities for
            (1 -> vx only, 2 -> vx, vy, 3 -> vx, vy, vz).
        gamma: adiabatic index used in the ideal-gas pressure and the sound speed.

    Returns:
        dict with plain numpy values (units stripped via ``.value``):
            'rho'   density [kg/m^3]
            'vx', 'vy', 'vz'  velocity components [m/s]; components beyond
                    ``n_dim`` are 0 (scalar) or a zero array
            'e_tot' total energy density [kg/(m s^2)]
            'v2', '|v|'  squared / absolute speed
            'p'     pressure from p = (e_tot - rho v^2 / 2) (gamma - 1)
            'M'     Mach number |v| / sqrt(gamma p / rho)
    """
    snap = {}

    rhyme.dataset.jump_to(snap_id)
    domain = rhyme.dataset.problem_domain
    v = rhyme.load_variables(silent=True)

    attrs = rhyme.dataset.active['h5']['attrs']
    r_unit = attrs['density_unit']
    v_unit = attrs['velocity_unit']
    p_unit = attrs['pressure_unit']

    def field(name):
        # First entry of a loaded variable (presumably the first/only box or
        # level -- TODO confirm), shaped onto the domain with singleton axes dropped.
        return v[name][0].reshape(domain).squeeze()

    rho_code = field('rho')  # density in code units
    snap['rho'] = (rho_code * r_unit).to(U.kg / U.m**3).value

    def velocity(momentum):
        # Divide momentum by density *in code units* so only the velocity unit
        # remains; dividing by the SI density (as before) was only correct when
        # the density unit converted to SI with a factor of 1.
        return (field(momentum) / rho_code * v_unit).to(U.m / U.s).value

    snap['vx'] = velocity('rho_u')
    if n_dim > 1:
        snap['vy'] = velocity('rho_v')
    else:
        snap['vy'] = 0 if 'rho_v' in v else np.zeros(domain).squeeze()
    if n_dim > 2:
        snap['vz'] = velocity('rho_w') if 'rho_w' in v else np.zeros(domain).squeeze()
    else:
        snap['vz'] = 0

    # e_tot is used below as an energy *density* (same dimension as rho v^2), so it
    # is converted with the pressure unit. Previously it was scaled by
    # length_unit**3 into Joules, which was only consistent for a 1 m length unit.
    # NOTE(review): assumes Rhyme stores e_tot in pressure units -- TODO confirm.
    snap['e_tot'] = (field('e_tot') * p_unit).to(U.kg / U.m / U.s**2).value

    snap['v2'] = snap['vx']**2 + snap['vy']**2 + snap['vz']**2
    snap['|v|'] = np.sqrt(snap['v2'])
    snap['p'] = (snap['e_tot'] - .5 * snap['rho'] * snap['v2']) * (gamma - 1)
    snap['M'] = snap['|v|'] / np.sqrt(gamma * snap['p'] / snap['rho'])

    return snap

In [None]:
# Snapshot registry: kind -> run label -> resolution -> {'path': final snapshot file}.
# `load_snapshot` additionally opens 'IC-000000.chombo.h5' from the same directory
# as each path, so both files are checked below.
SNAPSHOTS = {
    'linear': {
        'convergence': {
            d: {
                'path': f"convergence_test/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}-000001.chombo.h5"
            } for d in [2**x for x in range(5, 17)]
        },
        '1d': {
            d: {
                'path': f"sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}-000001.chombo.h5"
            } for d in [4096]
        },
    },
    'spherical': {
    },
}

# Check that every snapshot, and the initial-condition file next to it, is readable
for run_groups in SNAPSHOTS.values():
    for runs in run_groups.values():
        for res, snap in runs.items():
            ic_file = os.path.dirname(snap['path']) + '/IC-000000.chombo.h5'
            for label, file_path in (('snap', snap['path']), ('IC', ic_file)):
                if not Path(file_path).is_file():
                    print(f"[err] Cannot access {res} {label} @ {file_path}")
                else:
                    print(f"[accessible] {res} {label} @ {file_path}")

In [None]:
def _close_rhyme(ds):
    """Release the dataset handles held by a PyRhyme instance."""
    ds.dataset.close_current()
    ds.dataset.clean_all()


def load_snapshot(times=(0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05)):
    """Load every run registered in ``SNAPSHOTS`` at the requested times.

    For each run, the t = 0 state is read from 'IC-000000.chombo.h5' in the
    run's directory, and every ``t`` in ``times`` is read from the run's
    snapshot file using the snapshot nearest to ``t`` (see ``time_to_snap_id``).
    Each PyRhyme handle is closed as soon as it has been read, even on error.

    Args:
        times: iterable of target simulation times.

    Returns:
        dict: ``result[kind][label][res][t]`` -> field dict from ``load_rhyme``,
        with ``t = 0`` for the initial condition. Kinds without runs are omitted.
    """
    result = {}

    for kind, dims in SNAPSHOTS.items():
        if not dims:
            continue  # nothing registered for this kind yet

        result[kind] = {}
        for dim, snaps in dims.items():
            result[kind][dim] = {}
            # Labels like '1d' start with their dimensionality; convergence runs are 1D.
            n_dim = int(dim[:1]) if dim != 'convergence' else 1

            for res, snap in snaps.items():
                series = {}

                ic = PyRhyme(os.path.dirname(snap['path']) + '/IC-000000.chombo.h5')
                try:
                    series[0] = load_rhyme(time_to_snap_id(0, ic), ic, n_dim)
                finally:
                    _close_rhyme(ic)

                run = PyRhyme(snap['path'])
                try:
                    for t in times:
                        series[t] = load_rhyme(time_to_snap_id(t, run), run, n_dim)
                finally:
                    _close_rhyme(run)

                result[kind][dim][res] = series

    return result

# Only run when executed interactively (notebook kernel: no `__file__`), not on import.
if __name__ == '__main__' and '__file__' not in globals():
    DATA = load_snapshot()

In [None]:
def load_analytical_solution(
    prefix='analytical_solutions',
    solutions=((0.01, 'sedov_01.csv'), (0.03, 'sedov_03.csv'), (0.05, 'sedov_05.csv')),
):
    """Read tabulated analytical Sedov profiles from CSV files.

    Each file has one header row. The first four columns of every following
    row are read as floats x, rho, vx, p; blank rows are skipped.

    Args:
        prefix: directory containing the CSV files.
        solutions: iterable of ``(time, filename)`` pairs.

    Returns:
        dict: ``time -> {'x', 'rho', 'vx', 'p'}``, each a 1D float numpy array.
    """
    keys = ('x', 'rho', 'vx', 'p')
    result = {}

    for t, filename in solutions:
        path = f'{prefix}/{filename}'

        # newline='' is what the csv module expects for the files it reads
        with open(path, 'r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            rows = [[float(row[j]) for j in range(len(keys))] for row in reader if row]

        table = np.array(rows, dtype=float).reshape(-1, len(keys))
        result[t] = {key: table[:, j].copy() for j, key in enumerate(keys)}

    return result

# Only run when executed interactively (notebook kernel: no `__file__`), not on import.
if __name__ == '__main__' and '__file__' not in globals():
    SOLUTION = load_analytical_solution()

In [None]:
def plot_convergence_lineouts(
    times=(0.01, 0.03, 0.05),
    attrs=('rho', 'vx', 'p'),
    attr_labels=('ρ', 'v<sub>x</sub>', 'p'),
    norm_res_fac=32,
):
    """Overlay lineouts of every convergence-test resolution at several times.

    Builds a ``len(attrs)`` x ``len(times)`` subplot grid from the global
    ``DATA['linear']['convergence']`` (filled by ``load_snapshot``), writes it to
    ``FIGURE_DIR`` as 'sedov_explosion_1d_convergence.svg' / '.png' and shows it.

    Args:
        times: snapshot times to plot (one column each); must be keys of the
            per-resolution dicts in ``DATA``. Times without an entry in
            ``xranges`` get plotly's automatic x-range.
        attrs: fields to plot (one row each); each needs an entry in ``ticks``.
        attr_labels: y-axis titles, parallel to ``attrs``.
        norm_res_fac: every run is stretched onto ``[0, norm_res_fac]`` so
            different resolutions overlay on one x axis.
    """
    runs = DATA['linear']['convergence']
    colors = px.colors.sample_colorscale(
        px.colors.diverging.Portland_r, np.linspace(0, 1, len(runs) + 1),
    )

    # Per-field y-axis spec; for log axes 'range' is given in log10 units.
    ticks = {
        'rho': {
            'type': 'log',
            'range': (-1, 1),
            'vals': [10**x for x in range(-3, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-3, 2)],
        },
        'p': {
            'type': 'log',
            'range': (-5.5, -0.5),
            'vals': [10**x for x in range(-6, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-6, 2)],
        },
        'vx': {
            'type': 'linear',
            'range': (-0.02, 0.45),
            'vals': [0, 0.1, 0.2, 0.3, 0.4],
            # one label per tick value (the 0.4 tick previously had no label)
            'text': [f"{x}" for x in [0, 0.1, 0.2, 0.3, 0.4]],
        },
    }

    xranges = {
        0.01: (0, 5),
        0.03: (0, 10),
        0.05: (0, 15),
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1] * len(attrs), column_widths=[1] * len(times),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x [coarsest cell]',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the subplot titles; the next is x_title.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, attr in enumerate(attrs):
        spec = ticks[attr]
        for icol, t in enumerate(times):
            row, col = irow + 1, icol + 1
            first_col = icol == 0

            for ci, (res, run) in enumerate(runs.items()):
                y = np.array(run[t][attr])
                # len(y)+1 points on [0, norm_res_fac], i.e. one more x than y;
                # presumably plotly drops the unmatched last x -- TODO confirm.
                x = np.linspace(0, norm_res_fac, len(y) + 1)
                x -= 0.0625 / res  # FIXME: ad-hoc shift, kept unchanged

                fig.add_trace(go.Scatter(
                    x=x, y=y, name=f"{res}",
                    mode='lines', line=dict(color=colors[ci], width=1),
                    showlegend=(irow == 0 and icol == 0),  # one legend entry per resolution
                ), row, col)

            fig.update_xaxes(
                row=row, col=col, mirror=True, type='linear', range=xranges.get(t),
                title_font=dict(size=18), tickfont=dict(size=14),
            )

            fig.update_yaxes(
                row=row, col=col, mirror=True,
                title=attr_labels[irow] if first_col else None,
                type=spec['type'], range=spec['range'],
                tickmode='array', tickvals=spec['vals'], ticktext=spec['text'],
                showticklabels=first_col,
                title_font=dict(size=18), tickfont=dict(size=14),
            )

    width, height = 1000, 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=10, b=70, l=60, r=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d_convergence.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d_convergence.png", width=width, height=height, scale=3)
    fig.show()

# Only run when executed interactively (notebook kernel: no `__file__`), not on import.
if __name__ == '__main__' and '__file__' not in globals():
    plot_convergence_lineouts()

In [None]:
def plot_lineouts(
    times=(0.01, 0.03, 0.05),
    attrs=('rho', 'vx', 'p'),
    attr_labels=('ρ', 'v<sub>x</sub>', 'p'),
    norm_res_fac=4096,
):
    """Plot lineouts of the 1D Sedov run(s) at several times.

    Builds a ``len(attrs)`` x ``len(times)`` subplot grid from the global
    ``DATA['linear']['1d']`` (filled by ``load_snapshot``), writes it to
    ``FIGURE_DIR`` as 'sedov_explosion_1d.svg' / '.png' and shows it.

    Args:
        times: snapshot times to plot (one column each); must be keys of the
            per-resolution dicts in ``DATA``. Times without an entry in
            ``xranges`` get plotly's automatic x-range.
        attrs: fields to plot (one row each); each needs an entry in ``ticks``.
        attr_labels: y-axis titles, parallel to ``attrs``.
        norm_res_fac: every run is stretched onto ``[0, norm_res_fac]``.
    """
    runs = DATA['linear']['1d']
    colors = px.colors.sample_colorscale(
        px.colors.diverging.Portland, np.linspace(0, 1, len(runs) + 1),
    )

    # Per-field y-axis spec; for log axes 'range' is given in log10 units.
    ticks = {
        'rho': {
            'type': 'log',
            'range': (-2, 1),
            'vals': [10**x for x in range(-3, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-3, 2)],
        },
        'p': {
            'type': 'log',
            'range': (-5.5, -1.9),
            'vals': [10**x for x in range(-6, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-6, 2)],
        },
        'vx': {
            'type': 'linear',
            'range': (-0.02, 0.11),
            'vals': [0, 0.1, 0.2, 0.3, 0.4],
            # one label per tick value (previously one label short)
            'text': [f"{x}" for x in [0, 0.1, 0.2, 0.3, 0.4]],
        },
    }

    xranges = {
        0.01: (0, 150),
        0.03: (0, 250),
        0.05: (0, 350),
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1] * len(attrs), column_widths=[1] * len(times),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x [cells]',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the subplot titles; the next is x_title.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, attr in enumerate(attrs):
        spec = ticks[attr]
        for icol, t in enumerate(times):
            row, col = irow + 1, icol + 1
            first_col = icol == 0

            # Disabled overlay of the analytical solution (SOLUTION from
            # load_analytical_solution); its x coordinate is presumably not in
            # cells -- TODO map it onto this axis before re-enabling.
#             fig.add_trace(go.Scatter(
#                 x=SOLUTION[t]['x'][:2000], y=(SOLUTION[t][attr][:2000]), name='Sol.',
#                 mode='lines',  line=dict(color='black', width=1),
#                 showlegend=True if irow == icol == 0 else False
#             ), irow+1, icol+1)

            for ci, (res, run) in enumerate(runs.items()):
                y = np.array(run[t][attr])
                # len(y)+1 points on [0, norm_res_fac], i.e. one more x than y;
                # presumably plotly drops the unmatched last x -- TODO confirm.
                x = np.linspace(0, norm_res_fac, len(y) + 1)
                x -= 0.0625 / res  # FIXME: ad-hoc shift, kept unchanged

                fig.add_trace(go.Scatter(
                    x=x, y=y, name=f"{res}",
                    mode='lines', line=dict(color=colors[ci], width=1),
                    showlegend=(irow == 0 and icol == 0),  # one legend entry per resolution
                ), row, col)

            fig.update_xaxes(
                row=row, col=col, mirror=True, type='linear', range=xranges.get(t),
                title_font=dict(size=18), tickfont=dict(size=14),
            )

            fig.update_yaxes(
                row=row, col=col, mirror=True,
                title=attr_labels[irow] if first_col else None,
                type=spec['type'], range=spec['range'],
                tickmode='array', tickvals=spec['vals'], ticktext=spec['text'],
                showticklabels=first_col,
                title_font=dict(size=18), tickfont=dict(size=14),
            )

    width, height = 1000, 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=10, b=70, l=60, r=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d.png", width=width, height=height, scale=3)
    fig.show()

# Only run when executed interactively (notebook kernel: no `__file__`), not on import.
if __name__ == '__main__' and '__file__' not in globals():
    plot_lineouts()