# Sod's shock tube

## TOC
* [Packages & Settings & Global Functions](#packages)
* [Voxel Traversal](#voxel-traversal)
* [Snapshots](#snapshots)
* [Loading Snapshots](#loading-snapshots)
* [Linear Shock Tube Plots](#linear-plots)
* [Spherical Shock Tube Plots](#spherical-plots)

## Packages & Settings & Global Functions  <a class="anchor" id="packages"></a>

In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "simple_white"

from math import pi


def time_to_snap_id(t, ds):
    """
    Return the id of the snapshot whose time is closest to ``t``.

    Snapshots are scanned in order until the first one with a time strictly
    greater than ``t``. That snapshot and its predecessor are compared, and
    ties go to the predecessor. If no snapshot is later than ``t``, the last
    id is returned (``-1`` when there are no snapshots).

    Side effect: ``ds.dataset`` is left positioned on the last snapshot the
    function jumped to.
    """
    dataset = ds.dataset

    for snap_id in range(dataset.num_of_snapshots):
        dataset.jump_to(snap_id)
        later_time = dataset.time

        if not later_time > t:
            continue

        # The very first snapshot is already later than t: nothing before it.
        if snap_id == 0:
            return snap_id

        dataset.jump_to(snap_id - 1)
        earlier_time = dataset.time

        later_is_closer = abs(t - later_time) < abs(t - earlier_time)
        return snap_id if later_is_closer else snap_id - 1

    return dataset.num_of_snapshots - 1


def load_rhyme(snap_id, rhyme, n_dim, gamma=7./5.):
    """
    Load one Rhyme snapshot and derive primitive quantities in CGS units.

    param: snap_id: index of the snapshot to jump to
    param: rhyme: PyRhyme instance the snapshot is read from
    param: n_dim: dimensionality of the run (kept for API compatibility; the
           array shape is taken from ``problem_domain`` instead)
    param: gamma: adiabatic index used for pressure and Mach number
    return: dict with numpy arrays
            rho [g/cm^3], vx/vy/vz [cm/s] (0 if the momentum component is
            absent), e_tot [erg/cm^3], v2 [cm^2/s^2], p [erg/cm^3] and the
            Mach number M
    """
    snap = {}

    rhyme.dataset.jump_to(snap_id)
    domain = rhyme.dataset.problem_domain

    # Drop a trailing singleton axis of the problem domain (presumably the
    # unused z-extent of a 2D run).
    if domain[-1] == 1:
        domain = domain[:-1]

    v = rhyme.load_variables(silent=True)

    attrs = rhyme.dataset.active['h5']['attrs']
    r_unit = attrs['density_unit']
    v_unit = attrs['velocity_unit']
    # e_tot is used below as a total energy *density*
    # (p = (gamma - 1) * (e_tot - rho v^2 / 2)), so it carries the dimensions
    # of pressure, not pressure * volume.
    e_unit = attrs['pressure_unit']

    rho_code = v['rho'][0].reshape(domain)

    def _velocity(key):
        # Momentum density divided by density *in code units*; dividing by
        # the already-converted CGS density would be off by the density-unit
        # conversion factor.
        if key not in v:
            return 0
        return (v[key][0].reshape(domain) / rho_code * v_unit).to(U.cm / U.s).value

    snap['rho'] = (rho_code * r_unit).to(U.g / U.cm**3).value
    snap['vx'] = _velocity('rho_u')
    snap['vy'] = _velocity('rho_v')
    snap['vz'] = _velocity('rho_w')
    snap['e_tot'] = (v['e_tot'][0].reshape(domain) * e_unit).to(U.g / U.cm / U.s**2).value

    snap['v2'] = snap['vx']**2 + snap['vy']**2 + snap['vz']**2
    # Ideal-gas pressure from total energy density minus kinetic energy density.
    snap['p'] = (snap['e_tot'] - .5 * snap['rho'] * snap['v2']) * (gamma - 1)
    # Mach number: |v| / sound speed, with c_s = sqrt(gamma p / rho).
    snap['M'] = np.sqrt(snap['v2']) / np.sqrt(gamma * snap['p'] / snap['rho'])

    return snap

## Voxel Traversal  <a class="anchor" id="voxel-traversal"></a>

We use the "Fast Voxel Traversal Algorithm for Ray Tracing" of Amanatides & Woo (1987) to sample the simulation grid along a ray.

In [None]:
def ray_tracing(u, v, grid):
    """
    Return the cells of ``grid`` crossed by a ray (Amanatides & Woo, 1987).

    Positions use a 1-based cell convention: a point whose ``floor`` is ``c``
    lies in cell ``c`` (valid for ``1 <= c <= grid.shape``), spanning
    ``[c, c + 1)``. It is reported as the 0-based index ``c - 1``. Works for
    grids of any dimensionality.

    param: u: origin vector (in grid unit), one entry per grid axis
    param: v: direction vector (in grid unit); normalised internally
    param: grid: numpy array; only its shape is used
    return: integer numpy array of shape (n_cells, grid.ndim) with the 0-based
            indices of the visited cells in traversal order (empty if the
            origin lies outside the grid)
    raise: ValueError: if ``v`` is the zero vector
    """
    direction = np.asarray(v, dtype=float)
    norm = np.linalg.norm(direction)
    if norm == 0:
        raise ValueError("ray direction must be non-zero")
    direction = direction / norm

    origin = np.asarray(u, dtype=float)
    boxsize = grid.shape
    cell = np.floor(origin)

    print(f"u: {u}, v: {direction}, box: {boxsize}, init cell: {cell}")

    ndim = grid.ndim
    step = np.zeros(ndim)
    # t_max[i]: ray parameter at which the next boundary on axis i is crossed.
    # t_delta[i]: parameter distance between successive boundaries on axis i.
    # Axes the ray is parallel to are never crossed, hence +inf.
    t_max = np.full(ndim, np.inf)
    t_delta = np.full(ndim, np.inf)

    for i in range(ndim):
        if direction[i] > 0:
            step[i] = 1
            # cell c spans [c, c + 1): the boundary ahead is at c + 1
            t_max[i] = (cell[i] + 1 - origin[i]) / direction[i]
        elif direction[i] < 0:
            step[i] = -1
            # ... and the boundary behind is at c
            t_max[i] = (cell[i] - origin[i]) / direction[i]
        else:
            continue
        t_delta[i] = 1 / abs(direction[i])

    upper = np.asarray(boxsize)
    last_axis = ndim - 1
    result = []

    while np.all((cell >= 1) & (cell <= upper)):
        result.append([int(c) - 1 for c in cell])

        # Step along the axis whose boundary is reached first; on ties the
        # highest-numbered axis wins.
        axis = last_axis - int(np.argmin(t_max[::-1]))
        cell[axis] += step[axis]
        t_max[axis] += t_delta[axis]

    return np.array(result)


def plot_ray_tracing(u, v, grid):
    """
    Show the cells visited by ``ray_tracing(u, v, grid)`` as a connected path.

    2D grids are drawn as a Scatter trace, 3D grids as a Scatter3d trace.
    For any other dimensionality an empty 600x600 figure is still shown.
    """
    cells = ray_tracing(u, v, grid)

    line_style = dict(color='red', width=.5)
    marker_style = dict(size=3)
    fig = go.Figure()

    if grid.ndim == 2:
        xs = [cell[0] for cell in cells]
        ys = [cell[1] for cell in cells]
        fig.add_trace(go.Scatter(
            x=xs, y=ys,
            mode='lines+markers', line=line_style, marker=marker_style,
        ))
        framed_axis = dict(showgrid=True, showline=True, mirror=True)
        fig.update_layout(xaxis=dict(framed_axis), yaxis=dict(framed_axis))

    elif grid.ndim == 3:
        xs = [cell[0] for cell in cells]
        ys = [cell[1] for cell in cells]
        zs = [cell[2] for cell in cells]
        fig.add_trace(go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='lines+markers', line=line_style, marker=marker_style,
        ))
        scene_axis = dict(showgrid=True, showline=True)
        fig.update_layout(scene=dict(
            xaxis=dict(scene_axis),
            yaxis=dict(scene_axis),
            zaxis=dict(scene_axis),
        ))

    fig.update_layout(width=600, height=600, margin=dict(t=10, b=10, r=10, l=10))
    fig.show()
    

def convert_ray_tracing_to_x_y(u, v, grid, points):
    """
    Turn a list of cell indices into (distance, value) samples along a ray.

    param: u: ray origin (in grid unit)
    param: v: ray direction; normalised internally
    param: grid: numpy array the values are read from (any dimensionality)
    param: points: iterable of integer index vectors, e.g. ``ray_tracing`` output
    return: (x, y) numpy arrays with x[k] = (points[k] - u) . v / |v| and
            y[k] = grid[points[k]]; both empty when ``points`` is empty
    """
    origin = np.asarray(u, dtype=float)
    direction = np.asarray(v, dtype=float)
    direction = direction / np.linalg.norm(direction)

    x = [np.dot(np.asarray(p) - origin, direction) for p in points]
    # tuple(p) indexes a grid of any rank with one index per axis.
    y = [grid[tuple(p)] for p in points]

    return np.array(x), np.array(y)


def voxel_traversal(u, v, grid):
    """
    Sample ``grid`` along the ray starting at ``u`` with direction ``v``.

    Returns ``(x, y, points)``:
    - ``points`` are the cell indices produced by ``ray_tracing``;
    - ``x`` and ``y`` are the distances and grid values computed from them by
      ``convert_ray_tracing_to_x_y``.
    """
    visited = ray_tracing(u, v, grid)
    distances, values = convert_ray_tracing_to_x_y(u, v, grid, visited)
    return distances, values, visited


def plot_points(u, v, grid):
    """
    Overlay the profile sampled along a ray on a heatmap of the grid.

    For 3D grids the heatmap shows the slice ``grid[:, :, u[2]]``, so ``u[2]``
    must be a valid integer index. The grey curve places sample k at
    ``(u[0] + tx * s_k * v[0], u[1] + ty * s_k * v[1] + value_k)``. Here
    ``s_k`` is the distance from ``voxel_traversal`` and
    ``(tx, ty) = (cos a, sin a)`` with ``a`` the angle of ``(|v[0]|, |v[1]|)``.

    param: u: ray origin (in grid unit)
    param: v: ray direction (in grid unit)
    param: grid: 2D or 3D numpy array
    """
    distances, values, _ = voxel_traversal(u, v, grid)

    fig = go.Figure()

    fig.add_trace(go.Heatmap(z=grid[:, :, u[2]] if grid.ndim == 3 else grid, colorscale='Spectral'))

    # arctan2 stays defined when v[0] == 0 (a ray along the second axis), where
    # arctan(|v[1] / v[0]|) would divide by zero; otherwise the two agree.
    angle = np.arctan2(abs(v[1]), abs(v[0]))
    tx = np.cos(angle)
    ty = np.sin(angle)
    x = [u[0] + tx * s * v[0] for s in distances]
    y = [u[1] + ty * s * v[1] + val for s, val in zip(distances, values)]

    fig.add_trace(go.Scatter(
        x=x, y=y, mode='lines',
        line=dict(color='grey', width=1.5),
    ))

    fig.update_layout(
        width=600, height=600,
        xaxis=dict(mirror=True,),
        yaxis=dict(mirror=True,),
    )

    fig.show()


def _make_grid(res):
    """
    Build a 2D or 3D test field ``sin(x^2 + y^2 [+ z^2])``.

    Each axis has ``res[i]`` samples spaced 0.1 apart, starting at
    ``-res[i] / 20`` (so the field is roughly centred on the origin).

    param: res: number of cells per axis (2 or 3 entries)
    return: numpy array of shape ``tuple(res)``, or None (after printing an
            error) when ``res`` has neither 2 nor 3 entries
    """
    res = np.array(res)

    if len(res) not in (2, 3):
        print('[err] wrong dimensionality')
        return None

    step = 0.1
    # Build each axis from its integer cell count. A float-step np.mgrid uses
    # ceil((ub - lb) / step) points, which rounds up for some sizes (e.g. 11
    # cells -> 12 points) and makes the grid shape differ from ``res``.
    axes = [-(n / 2 / 10) + np.arange(n) * step for n in res]
    coords = np.meshgrid(*axes, indexing='ij')

    return np.sin(sum(c**2 for c in coords))
    
# if __name__ == '__main__' and '__file__' not in globals():
#     plot_ray_tracing(u=[32, 32, 32], v=[1, 1, 0.5], grid=_make_grid([64, 64, 64]))
#     plot_ray_tracing(u=[32, 32], v=[0, 1], grid=_make_grid([64, 64]))
#     plot_points(u=[32, 32, 32], v=[1, 1, 0], grid=_make_grid([64, 64, 64]))
#     plot_points(u=[32, 32], v=[1, 1], grid=_make_grid([64, 64]))

## Snapshots  <a class="anchor" id="snapshots"></a>

In [None]:
# Chombo HDF5 outputs of the four Sod shock-tube runs, keyed by geometry
# ('linear' / 'spherical') and dimensionality ('2d' / '3d'). Every run follows
# the layout <run>_<dim>/<run>_<dim>/<run>_<dim>-000001.chombo.h5 relative to
# the working directory.
SNAPSHOTS = {
    geometry: {
        dim: {'path': f"{run}_{dim}/{run}_{dim}/{run}_{dim}-000001.chombo.h5"}
        for dim in ('2d', '3d')
    }
    for geometry, run in (('linear', 'sod_s_shock'), ('spherical', 'sod_s_shock_sphere'))
}

# Report which snapshot files can be found before trying to load them.
for kind in SNAPSHOTS.values():
    for res, snap in kind.items():
        status = "[accessible]" if Path(snap['path']).is_file() else "[err] Cannot access"
        print(f"{status} {res} snap @ {snap['path']}")

## Loading Snapshots  <a class="anchor" id="loading-snapshots"></a>

In [None]:
def load_snapshot(times):
    """
    Load every run listed in SNAPSHOTS at the requested times.

    param: times: sequence of simulation times; for each, the snapshot with
           the closest time (see ``time_to_snap_id``) is loaded
    return: nested dict ``result[kind][res][t]`` holding ``load_rhyme``
            outputs, with kind/res the keys of SNAPSHOTS
    """
    result = {}
    for kind, snaps in SNAPSHOTS.items():
        result[kind] = {}

        for res, snapinfo in snaps.items():
            result[kind][res] = {}
            n_dim = int(res[:1])  # '2d' -> 2, '3d' -> 3
            r = PyRhyme(snapinfo['path'])

            try:
                for t in times:
                    result[kind][res][t] = load_rhyme(time_to_snap_id(t, r), r, n_dim)
            finally:
                # Release each opened run, not only the last one of each kind.
                r.dataset.close_current()
                r.dataset.clean_all()

    return result

if __name__ == '__main__' and '__file__' not in globals():
    DATA = load_snapshot([0.0, 0.1, 0.2, 0.3, 0.4])

## Linear Shock Tube Plots  <a class="anchor" id="linear-plots"></a>

In [None]:
def plot_linear_shock_tube():
    """
    Plot rho, vx and e_tot profiles of the linear shock tube.

    One figure per (time, quantity). Each figure overlays every run in
    DATA['linear'], sampled with ``voxel_traversal`` along the last axis. The
    ray starts at (63.5, 63.5, 1) for the '3d' run and at (63.5, 1) otherwise.
    """
    runs = DATA['linear']

    for t in runs['2d'].keys():
        for attr in ('rho', 'vx', 'e_tot'):
            fig = go.Figure()

            for res, snap in runs.items():
                is_3d = res == '3d'
                origin = [63.5, 63.5, 1] if is_3d else [63.5, 1]
                direction = [0, 0, 1] if is_3d else [0, 1]

                x, y, _ = voxel_traversal(origin, direction, snap[t][attr])
                fig.add_trace(go.Scatter(x=x, y=y, name=f"{res}-{t}"))

            fig.update_layout(
                title=f"{attr} @ {t}",
                width=600,
                height=600,
                xaxis=dict(mirror=True),
                yaxis=dict(mirror=True),
            )
            fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    plot_linear_shock_tube()

## Spherical Shock Tube Plots  <a class="anchor" id="spherical-plots"></a>

In [None]:
def plot_spherical_shock_tube_maps():
    """
    Show heatmaps of rho, vx, vy, v2 and e_tot for the spherical shock tube.

    One figure per (time, quantity, run). 2D runs are shown whole. 3D runs
    are shown as the slice through the middle index of their first axis
    (index 64 for a 128-cell axis), so any grid size works.
    """
    runs = DATA['spherical']

    for t in runs['2d'].keys():
        for attr in ['rho', 'vx', 'vy', 'v2', 'e_tot']:
            for res, snap in runs.items():
                field = snap[t][attr]
                z = field if res == '2d' else field[field.shape[0] // 2, :, :]

                fig = go.Figure()
                fig.add_trace(go.Heatmap(z=z, colorscale='Spectral',))

                fig.update_layout(
                    title=f"{res} {attr} @ {t}",
                    width=600, height=600,
                    xaxis=dict(mirror=True,),
                    yaxis=dict(mirror=True,),
                )

                fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    plot_spherical_shock_tube_maps()

In [None]:
def plot_spherical_shock_tube_lineouts():
    """
    Plot rho, vx and e_tot lineouts of the spherical shock tube.

    One figure per (time, quantity). Each figure overlays every run in
    DATA['spherical'], sampled with ``voxel_traversal`` along the last axis.
    The ray starts at (63.5, 63.5, 1) for the '3d' run and at (63.5, 1)
    otherwise.
    """
    runs = DATA['spherical']

    for t in runs['2d'].keys():
        for attr in ('rho', 'vx', 'e_tot'):
            fig = go.Figure()

            for res, snap in runs.items():
                is_3d = res == '3d'
                origin = [63.5, 63.5, 1] if is_3d else [63.5, 1]
                direction = [0, 0, 1] if is_3d else [0, 1]

                x, y, _ = voxel_traversal(origin, direction, snap[t][attr])
                fig.add_trace(go.Scatter(x=x, y=y, name=f"{res}-{t}"))

            fig.update_layout(
                title=f"{attr} @ {t}",
                width=600,
                height=600,
                xaxis=dict(mirror=True),
                yaxis=dict(mirror=True),
            )
            fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    plot_spherical_shock_tube_lineouts()