# Planar Sedov's explosion

## TOC
* [Packages & Settings](#packages)
* [Snapshots](#snapshots)
* [Loading Snapshots](#loading)
* [Convergence Study](#convergence)
* [Analytical Solution](#analytical)
* [Comparison to Analytical Solution](#comparison)

## Packages & Settings  <a class="anchor" id="packages"></a>

In [None]:
from pyrhyme import PyRhyme

import os
from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C

# !pip3 install -U plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
plotly.io.templates.default = "simple_white"

from math import pi
from scipy.special import gamma as Γ

# Output directory for the SVG/PNG files written by the plotting functions in
# this file; created up front (parents included) and left alone if it exists.
FIGURE_DIR = './figures'
Path(FIGURE_DIR).mkdir(parents=True, exist_ok=True)

def time_to_snap_id(t, ds):
    """Return the index of the snapshot whose time is nearest to ``t``.

    Snapshots are visited in order via ``ds.dataset.jump_to`` until one with
    a time strictly greater than ``t`` is found; that snapshot and its
    predecessor are then compared and the closer one wins (the predecessor
    on a tie). If no snapshot is later than ``t`` the last index is returned.

    Note: the dataset is left positioned on whichever snapshot was visited
    last, not necessarily on the returned one.
    """
    dataset = ds.dataset
    total = dataset.num_of_snapshots

    idx = 0
    while idx < total:
        dataset.jump_to(idx)
        later = dataset.time
        if later > t:
            break
        idx += 1
    else:
        # Every snapshot is at or before t.
        return total - 1

    if idx == 0:
        # Already the first snapshot: nothing earlier to compare against.
        return 0

    dataset.jump_to(idx - 1)
    earlier = dataset.time

    if abs(t - later) < abs(t - earlier):
        return idx
    return idx - 1


def load_rhyme(snap_id, rhyme, n_dim, g=7./5.):
    """Load one snapshot and derive primitive hydrodynamic fields from it.

    Parameters
    ----------
    snap_id
        Snapshot index handed to ``rhyme.dataset.jump_to``.
    rhyme
        Handle exposing ``dataset`` and ``load_variables`` (a ``PyRhyme``
        instance wherever this file calls it).
    n_dim
        Number of spatial dimensions; momentum components beyond it are
        replaced by zeros.
    g
        Adiabatic index used for the pressure and the Mach number.

    Returns
    -------
    dict
        Keys ``rho``, ``vx``, ``vy``, ``vz``, ``e_tot``, ``v2``, ``|v|``,
        ``p`` and ``M``. ``rho``, the velocities and ``e_tot`` are the bare
        values after conversion to kg/m^3, m/s and kg m^2/s^2 respectively.
    """
    rhyme.dataset.jump_to(snap_id)
    domain = rhyme.dataset.problem_domain

    v = rhyme.load_variables(silent=True)

    attrs = rhyme.dataset.active['h5']['attrs']
    r_unit = attrs['density_unit']
    v_unit = attrs['velocity_unit']
    p_unit = attrs['pressure_unit']
    # NOTE(review): this is pressure * volume (an energy), yet e_tot is combined
    # below with rho * v^2 (an energy density). The two agree numerically only
    # if length_unit is 1 m -- confirm against how e_tot is stored.
    e_unit = p_unit * attrs['length_unit']**3

    def field(name):
        # First entry of the variable, reshaped to the domain with
        # singleton axes dropped.
        return v[name][0].reshape(domain).squeeze()

    # Density in code units: dividing the momentum by this (and not by the
    # SI-converted density) yields a velocity in code units, which is what
    # v_unit expects. The two only coincide when density_unit is 1 kg/m^3.
    rho_code = field('rho')

    def velocity(name):
        return (field(name) / rho_code * v_unit).to(U.m / U.s).value

    snap = {}
    snap['rho'] = (rho_code * r_unit).to(U.kg / U.m**3).value
    snap['vx'] = velocity('rho_u')

    if n_dim > 1:
        snap['vy'] = velocity('rho_v')
    else:
        snap['vy'] = 0 if 'rho_v' in v else np.zeros(domain).squeeze()

    if n_dim > 2:
        snap['vz'] = velocity('rho_w') if 'rho_w' in v else np.zeros(domain).squeeze()
    else:
        snap['vz'] = 0

    snap['e_tot'] = (field('e_tot') * e_unit).to(U.kg * U.m**2 / U.s**2).value

    snap['v2'] = snap['vx']**2 + snap['vy']**2 + snap['vz']**2
    snap['|v|'] = np.sqrt(snap['v2'])
    # Ideal-gas closure: p = (gamma - 1) * (e_tot - kinetic energy).
    snap['p'] = (snap['e_tot'] - .5 * snap['rho'] * snap['v2']) * (g - 1)
    # Mach number relative to the sound speed sqrt(gamma * p / rho).
    snap['M'] = snap['|v|'] / np.sqrt(g * snap['p'] / snap['rho'])

    return snap

def make_func():
    """Build interpolators for the tabulated profiles stored under ``./data``.

    For each of the time labels ``'0.01'``, ``'0.03'`` and ``'0.05'`` the
    tab-separated, header-less file ``./data/t_<label>`` is read with the
    columns ``r``, ``rho``, ``p``, ``vx``.

    Returns
    -------
    dict
        ``func[label][var]`` is a ``scipy.interpolate.interp1d`` of ``var``
        against ``r`` for ``var`` in ``rho``, ``p``, ``vx``.
    """
    # Imported locally: neither module is imported at the top of this file.
    import pandas as pd
    import scipy.interpolate

    columns = ['r', 'rho', 'p', 'vx']

    func = {}
    for t in ('0.01', '0.03', '0.05'):
        df = pd.read_csv("./data/t_%s" % t, delimiter='\t', names=columns)
        func[t] = {
            var: scipy.interpolate.interp1d(df['r'], df[var])
            for var in columns[1:]
        }

    return func

## Snapshots  <a class="anchor" id="snapshots"></a>

In [None]:
# Grid sizes of the runs registered below.
_CONVERGENCE_RESOLUTIONS = tuple(1 << x for x in range(5, 17))  # 32, 64, ..., 65536
_LINEOUT_RESOLUTIONS = (4096,)

# Snapshot registry, indexed as SNAPSHOTS[geometry][study][resolution]['path'].
# 'spherical' is declared but has no runs.
SNAPSHOTS = {
    'linear': {
        'convergence': {
            d: {
                'path': f"convergence_test/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}-000001.chombo.h5"
            }
            for d in _CONVERGENCE_RESOLUTIONS
        },
        '1d': {
            d: {
                'path': f"sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}/sedov_explosion_1d_{d:05d}-000001.chombo.h5"
            }
            for d in _LINEOUT_RESOLUTIONS
        },
    },
    'spherical': {},
}

# Report, for every registered snapshot, whether its file exists on disk.
for dims in SNAPSHOTS.values():
    for snaps in dims.values():
        for res, snap in snaps.items():
            if Path(snap['path']).is_file():
                print(f"[accessible] {res} snap @ {snap['path']}")
                continue

            print(f"[err] Cannot access {res} snap @ {snap['path']}")

## Loading Snapshots  <a class="anchor" id="loading"></a>

In [None]:
def _release_dataset(handle):
    """Call ``close_current()`` then ``clean_all()`` on the handle's dataset."""
    handle.dataset.close_current()
    handle.dataset.clean_all()


def load_snapshot(times=(0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05)):
    """Load every snapshot registered in ``SNAPSHOTS`` at the requested times.

    For each registered run, the initial condition file
    (``IC-000000.chombo.h5`` in the same directory as the run's path) is
    loaded under key ``0``, then the run itself is loaded at the snapshot
    nearest to each entry of ``times``.

    Parameters
    ----------
    times
        Iterable of simulation times; each becomes a key of the result.

    Returns
    -------
    dict
        ``result[kind][dim][res][t]`` is the dict produced by ``load_rhyme``.
        Kinds with no registered runs are omitted.
    """
    result = {}

    for kind, dims in SNAPSHOTS.items():
        if not dims:
            continue

        result[kind] = {}

        for dim, snaps in dims.items():
            result[kind][dim] = {}
            # Study keys are either 'convergence' (1-D runs) or start with
            # the dimension digit, e.g. '1d'.
            n_dim = 1 if dim == 'convergence' else int(dim[:1])

            for res, snap in snaps.items():
                loaded = result[kind][dim][res] = {}
                ic_path = os.path.join(os.path.dirname(snap['path']), 'IC-000000.chombo.h5')

                # Release the handles of every run, not only the last one,
                # and do so even if loading fails part-way through.
                handles = []
                try:
                    ic = PyRhyme(ic_path)
                    handles.append(ic)
                    loaded[0] = load_rhyme(time_to_snap_id(0, ic), ic, n_dim)

                    run = PyRhyme(snap['path'])
                    handles.append(run)
                    for t in times:
                        loaded[t] = load_rhyme(time_to_snap_id(t, run), run, n_dim)
                finally:
                    for handle in handles:
                        _release_dataset(handle)

    return result

# Runs only as the main module with no __file__ defined -- presumably i.e. in a
# notebook/interactive session, not when executed as a script or imported.
# DATA is the module-level global the plotting functions read.
if __name__ == '__main__' and '__file__' not in globals():
    DATA = load_snapshot()

## Convergence Study  <a class="anchor" id="convergence"></a>

In [None]:
def plot_convergence_lineouts(
    times=(0.01, 0.03, 0.05),
    attrs=('rho', 'vx', 'p'),
    attr_labels=('ρ', 'v<sub>x</sub>', 'p'),
    base_res=32,
):
    """Plot line-outs of every convergence run on a grid of attrs x times.

    Reads the module-level ``DATA['linear']['convergence']``, draws one line
    per resolution in each panel, writes the figure as SVG and PNG into
    ``FIGURE_DIR`` and shows it.

    Parameters
    ----------
    times
        Times to plot, one column each; every value must be a key of the
        loaded runs.
    attrs
        Fields to plot, one row each; must be among ``'rho'``, ``'p'``,
        ``'vx'`` (the fields with axis settings defined here).
    attr_labels
        Y-axis titles, parallel to ``attrs``.
    base_res
        Extent of the x axis: all runs are mapped onto ``[0, base_res)`` so
        that x is measured in cells of the coarsest grid.
    """
    runs = DATA['linear']['convergence']
    colors = px.colors.sample_colorscale(
        px.colors.diverging.Portland_r, np.linspace(0, 1, len(runs) + 1)
    )

    # Tick values and labels are derived from one list so they cannot get
    # out of step.
    vx_ticks = [0, 0.1, 0.2, 0.3, 0.4]

    ticks = {
        'rho': {
            'type': 'log',
            'range': (-1, 1),
            'vals': [10**x for x in range(-3, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-3, 2)],
        },
        'p': {
            'type': 'log',
            'range': (-5.5, -0.5),
            'vals': [10**x for x in range(-6, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-6, 2)],
        },
        'vx': {
            'type': 'linear',
            'range': (-0.02, 0.45),
            'vals': vx_ticks,
            'text': [f"{x}" for x in vx_ticks],
        },
    }

    # Hand-picked x windows; any other time falls back to autorange.
    xranges = {
        0.01: (0, 5),
        0.03: (0, 10),
        0.05: (0, 15),
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1]*len(attrs), column_widths=[1]*len(times),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x [coarsest cell]',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the column titles; nudge them up.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # The next annotation is the shared x title; nudge it down.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, attr in enumerate(attrs):
        axis = ticks[attr]

        for icol, t in enumerate(times):
            first_column = icol == 0

            for ci, (res, snaps) in enumerate(runs.items()):
                y = np.array(snaps[t][attr])
                # One x per sample (left cell edges), so x and y have equal length.
                x = np.linspace(0, base_res, len(y), endpoint=False)
                x -= 0.0625 / res  # FIXME

                fig.add_trace(go.Scatter(
                    x=x, y=y, name=f"{res}",
                    mode='lines', line=dict(color=colors[ci], width=1),
                    # Legend entries come from the top-left panel only.
                    showlegend=irow == icol == 0,
                ), irow+1, icol+1)

            fig.update_xaxes(
                row=irow+1, col=icol+1, mirror=True, type='linear', range=xranges.get(t),
                title_font=dict(size=18), tickfont=dict(size=14),
            )

            fig.update_yaxes(
                title=attr_labels[irow] if first_column else None, type=axis['type'], range=axis['range'],
                row=irow+1, col=icol+1, mirror=True,
                tickmode='array', tickvals=axis['vals'], ticktext=axis['text'], showticklabels=first_column,
                title_font=dict(size=18), tickfont=dict(size=14),
            )

    width, height = 1000, 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=10, b=70, l=60, r=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d_convergence.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d_convergence.png", width=width, height=height, scale=3)
    fig.show()

# Runs only as the main module with no __file__ defined -- presumably i.e. in a
# notebook/interactive session, not when executed as a script or imported.
if __name__ == '__main__' and '__file__' not in globals():
    plot_convergence_lineouts()

## Analytical Solution  <a class="anchor" id="analytical"></a>

Based on [Peter Creasey](https://github.com/Lowingbn/iccpy/blob/8a173c63440fa8697e49d537f1e8792b8037a3d6/flash/sedov_analytic.py)'s code.

In [None]:
def sedov(t, E_0=1, rho_0=1, g =1.4, n=1000, nu=1, res=4096):
    """Evaluate the analytical Sedov blast-wave profiles at time ``t``.

    Parameters
    ----------
    t
        Time at which the solution is evaluated.
    E_0
        Energy entering the scale factor ``f`` (together with ``rho_0``).
    rho_0
        Density used to scale the returned density and pressure.
    g
        Adiabatic index.
    n
        Number of sample points of the variable ``v``.
    nu
        Dimensionality: ``r**nu * pi**(nu/2) / Γ(nu/2 + 1)`` below is the
        volume of a ``nu``-dimensional ball.
    res
        The returned radius is multiplied by ``res / 2``.

    Returns
    -------
    tuple of four arrays of length ``n``
        ``(radius, density, velocity, pressure)``.
        NOTE(review): velocity and pressure carry the hard-coded factors
        0.03 and 0.001, whose origin is not visible here -- presumably a
        conversion to the simulation's units; confirm before reusing.
    """
    # Recurring combinations of g and nu.
    gp1, gm1, gm2 = g + 1, g - 1, g - 2
    gp1_gm1 = gp1 / gm1
    nugm1 = nu * gm1
    nu2 = nu + 2
    
    # n uniformly spaced samples of v over [2/((nu+2) g), 4/((nu+2)(g+1))].
    vrange = (2 / nu2 / g, 4 / nu2 / gp1)
    dv = (vrange[1] - vrange[0]) / (n - 1)
    v = vrange[0] + np.arange(n) * dv
   
    # Exponents of the closed-form solution (a6 and a7 are computed but not
    # used further down).
    a0 = 2 / nu2
    a2 = -gm1 / (2 * gm1 + nu)
    a1 = nu2 * g / (2 + nugm1) * (-2 * nu * gm2 / (g * nu2**2) - a2)
    a3 = nu / (2 * gm1 + nu)
    a4 = -a1 * nu2 / gm2
    a5 = 2 / gm2
    a6 = g / (2 * gm1 + nu)
    a7 =  -a1 * (2 + nugm1) / (nu * gm2)
    
    # beta[k] = b[k] * v + offset[k]: four functions linear in v, shape (4, n).
    b = nu2 * gp1 * np.array([
        0.25,
        (g / gm1) / 2,
        -(2 + nugm1) / 2 / (nu2 * gp1 - 2 * (2 + nugm1)),
        -1 / (2 * gm1),
    ])

    beta = np.outer(b, v) + gp1 * np.array([
        0.0,
        -1 / gm1,
        nu2 / (nu2 * gp1 - 2 * (2 + nugm1)),
        1 / gm1,
    ]).reshape((4,1))
    
    # beta[1] vanishes (up to round-off) at the first sample, so the log is
    # not finite there; the first entries are overwritten further down.
    lb = np.log(beta)
    
    # Dimensionless profiles as products of powers of beta.
    r = np.exp(-a0 * lb[0] - a2 * lb[1] - a1 * lb[2])
    
    rho = gp1_gm1 * np.exp(a3 * lb[1] + a5 * lb[3] + a4 * lb[2])
    p = np.exp(nu * a0 * lb[0] + (a5 + 1) * lb[3] + (a4 - 2 * a1) * lb[2]) * 8 / (gp1 * nu2**2)
    u = beta[0] * 4 * r / (gp1 * nu2)
    
    # Fix the first sample by hand: r, rho and u are set to 0 and p is
    # copied from its neighbour.
    r[0], rho[0], u[0], p[0] = 0, 0, 0, p[1]

    # Volume enclosed within each radius, then the energy density
    # (kinetic + internal) integrated over volume with the trapezoidal rule.
    V = np.power(r, nu) * (pi**(nu / 2) / Γ(nu / 2 + 1))
    de = rho * u**2 * 0.5 + p / gm1
    q = np.inner((de[1:] + de[:-1]), np.diff(V)) / 2  

    # Factor that scales the dimensionless profiles to this t, rho_0 and E_0.
    f = (q * (t**nu) * rho_0 / E_0)**(-1 / nu2)

    # NOTE(review): res / 2, 0.03 and 0.001 are hard-coded rescalings; their
    # meaning cannot be determined from this function -- TODO confirm.
    return r * f * t * res / 2, rho * rho_0, u * f * 0.03, p * f**2 * rho_0 * 0.001

## Comparison to Analytical Solution  <a class="anchor" id="comparison"></a>

In [None]:
def plot_lineouts(
    times=(0.01, 0.03, 0.05),
    attrs=('rho', 'vx', 'p'),
    attr_labels=('ρ', 'v<sub>x</sub>', 'p'),
    res=4096,
):
    """Plot the 1-D runs against the analytical solution on an attrs x times grid.

    For every requested time the analytical profiles are computed with
    ``sedov`` and drawn in black; each run in the module-level
    ``DATA['linear']['1d']`` is overlaid in green. The figure is written as
    SVG and PNG into ``FIGURE_DIR`` and shown.

    Parameters
    ----------
    times
        Times to plot, one column each; every value must be a key of the
        loaded runs.
    attrs
        Fields to plot, one row each; must be among ``'rho'``, ``'p'``,
        ``'vx'`` (the fields with settings defined here).
    attr_labels
        Y-axis titles, parallel to ``attrs``.
    res
        Resolution passed to ``sedov`` to scale the analytical radius.
    """
    # Analytical profiles for exactly the requested times.
    sol = {}
    for t in times:
        r, rho, vx, p = sedov(t, 1, 1, 1.4, n=1000, nu=1, res=res)
        sol[t] = {'r': r, 'rho': rho, 'vx': vx, 'p': p}

    # 'init' is the value drawn for the analytical curve beyond its last radius.
    settings = {
        'rho': {
            'init': 1,
            'type': 'log', 'range': (-1.5, 1),
            'vals': [10**x for x in range(-3, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-3, 2)],
        },
        'p': {
            'init': 1e-5,
            'type': 'log', 'range': (-5.5, -1.9),
            'vals': [10**x for x in range(-6, 2)],
            'text': [f"10<sup>{x}</sup>" if x != 0 else '1' for x in range(-6, 2)],
        },
        'vx': {
            'init': 0,
            'type': 'linear', 'range': (-0.01, 0.08),
            'vals': [0, 0.02, 0.04, 0.06, 0.08],
            'text': [f"{x}" for x in [0, 0.02, 0.04, 0.06, 0.08]],
        },
    }

    # Hand-picked x windows; any other time falls back to autorange.
    xranges = {
        0.01: (0, 150),
        0.03: (0, 250),
        0.05: (0, 350),
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1]*len(attrs), column_widths=[1]*len(times),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x [cells]',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the column titles; nudge them up.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # The next annotation is the shared x title; nudge it down.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, attr in enumerate(attrs):
        axis = settings[attr]

        for icol, t in enumerate(times):
            first_column = icol == 0
            # Legend entries come from the top-left panel only.
            in_legend = irow == icol == 0

            # Extend the analytical curve past its last radius: drop
            # vertically to the 'init' value there and hold it out to x = 500.
            solx = list(sol[t]['r'])
            soly = list(sol[t][attr])
            solx += [solx[-1], 500]
            soly += [axis['init'], axis['init']]

            fig.add_trace(go.Scatter(
                x=solx, y=soly, name='Sol.',
                mode='lines',  line=dict(color='black', width=1),
                showlegend=in_legend,
            ), irow+1, icol+1)

            # run_res (not res) so the function parameter is not shadowed.
            for run_res, snaps in DATA['linear']['1d'].items():
                y = np.array(snaps[t][attr])
                # One x per sample (left cell edges), so x and y have equal length.
                x = np.linspace(0, run_res, len(y), endpoint=False)
                x -= 0.0625 / run_res  # FIXME

                fig.add_trace(go.Scatter(
                    x=x, y=y, name=f"{run_res}",
                    mode='lines', line=dict(color='green', width=1),
                    showlegend=in_legend,
                ), irow+1, icol+1)

            fig.update_xaxes(
                row=irow+1, col=icol+1, mirror=True, type='linear', range=xranges.get(t),
                tickmode='array', tickvals=[1, 100, 200, 300], ticktext=['0', '100', '200', '300'],
                title_font=dict(size=18), tickfont=dict(size=14),
            )

            fig.update_yaxes(
                title=attr_labels[irow] if first_column else None, type=axis['type'], range=axis['range'],
                row=irow+1, col=icol+1, mirror=True,
                tickmode='array', tickvals=axis['vals'], ticktext=axis['text'], showticklabels=first_column,
                title_font=dict(size=18), tickfont=dict(size=14),
            )

    width, height = 1000, 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=10, b=70, l=60, r=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sedov_explosion_1d.png", width=width, height=height, scale=3)
    fig.show()

# Runs only as the main module with no __file__ defined -- presumably i.e. in a
# notebook/interactive session, not when executed as a script or imported.
if __name__ == '__main__' and '__file__' not in globals():
    plot_lineouts()