# Sod's shock tube

## TOC
* [Packages & Settings & Global Functions](#packages)
* [Analytical Solution](#analytical-solution)
* [Voxel Traversal](#voxel-traversal)
* [Snapshots](#snapshots)
* [Loading Snapshots](#loading-snapshots)
* [Linear Shock Tube Plots](#linear-plots)
* [Spherical Shock Tube Plots](#spherical-plots)

## Packages & Settings & Global Functions  <a class="anchor" id="packages"></a>

In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "simple_white"

from math import pi

import csv

# Output directory for every figure written with `fig.write_image` in this
# notebook; created (with parents) if it does not exist yet.
FIGURE_DIR = './figures'
Path(FIGURE_DIR).mkdir(parents=True, exist_ok=True)

def time_to_snap_id(t, ds):
    """Return the index of the stored snapshot whose time is closest to ``t``.

    Snapshots are scanned in order. The scan stops at the first snapshot whose
    time is strictly greater than ``t`` and chooses between it and its
    predecessor; on equal distance the predecessor wins. When no snapshot is
    later than ``t`` the last index is returned. The dataset is left positioned
    on whichever snapshot was read last.

    param: t: target time (same units as ``ds.dataset.time``)
    param: ds: reader exposing ``dataset.num_of_snapshots``,
               ``dataset.jump_to(i)`` and ``dataset.time``
    """
    dataset = ds.dataset
    for idx in range(dataset.num_of_snapshots):
        dataset.jump_to(idx)
        later_time = dataset.time
        if not later_time > t:
            continue

        # First snapshot already past t: nothing earlier to compare with.
        if idx == 0:
            return idx

        dataset.jump_to(idx - 1)
        earlier_time = dataset.time
        closer_to_later = abs(t - later_time) < abs(t - earlier_time)
        return idx if closer_to_later else idx - 1

    return dataset.num_of_snapshots - 1


def load_rhyme(snap_id, rhyme, n_dim, gamma=7./5.):
    """
    Load one snapshot from a PyRhyme reader and derive primitive fields in SI units.

    param: snap_id: index of the snapshot to load (passed to ``rhyme.dataset.jump_to``)
    param: rhyme: an opened ``PyRhyme`` reader
    param: n_dim: dimensionality of the run; velocity components beyond
                  ``n_dim`` are zero (array or scalar 0, as below)
    param: gamma: adiabatic index used for pressure and sound speed (default 7/5)

    return: dict with ``rho`` [kg/m^3], ``vx``/``vy``/``vz`` [m/s],
            ``e_tot`` [J/m^3], ``v2``, ``|v|``, ``p`` [Pa] and ``M`` (Mach
            number), each reshaped to the problem domain and squeezed.
    """
    snap = {}

    rhyme.dataset.jump_to(snap_id)
    domain = rhyme.dataset.problem_domain

    v = rhyme.load_variables(silent=True)

    attrs = rhyme.dataset.active['h5']['attrs']
    r_unit = attrs['density_unit']
    v_unit = attrs['velocity_unit']
    p_unit = attrs['pressure_unit']

    def field(name):
        # NOTE(review): index 0 is presumably the single stored level/box -- confirm with PyRhyme.
        return v[name][0].reshape(domain).squeeze()

    # Velocities are momentum / density with BOTH in code units; dividing the
    # code-unit momentum by the SI density (as before) scaled every velocity
    # by 1 / density_unit whenever that unit was not exactly kg/m^3.
    rho_code = field('rho')

    def velocity(momentum):
        return (field(momentum) / rho_code * v_unit).to(U.m / U.s).value

    snap['rho'] = (rho_code * r_unit).to(U.kg / U.m**3).value
    snap['vx'] = velocity('rho_u')
    if n_dim > 1:
        snap['vy'] = velocity('rho_v')
    else:
        snap['vy'] = 0 if 'rho_v' in v else np.zeros(domain).squeeze()
    if n_dim > 2:
        snap['vz'] = velocity('rho_w') if 'rho_w' in v else np.zeros(domain).squeeze()
    else:
        snap['vz'] = 0

    # The pressure formula below treats e_tot as an energy *density*, so it is
    # converted with the pressure unit (J/m^3). The former
    # pressure_unit * length_unit**3 (an energy) only agreed when length_unit was 1 m.
    snap['e_tot'] = (field('e_tot') * p_unit).to(U.kg / (U.m * U.s**2)).value

    snap['v2'] = snap['vx']**2 + snap['vy']**2 + snap['vz']**2
    snap['|v|'] = np.sqrt(snap['v2'])
    # Ideal-gas equation of state: p = (gamma - 1) * (E - rho v^2 / 2)
    snap['p'] = (snap['e_tot'] - .5 * snap['rho'] * snap['v2']) * (gamma - 1)
    snap['M'] = snap['|v|'] / np.sqrt(gamma * snap['p'] / snap['rho'])

    return snap

## Analytical Solution  <a class="anchor" id="analytical-solution"></a>

In [None]:
def load_analytical_solution(prefix='analytical_solutions', times=(0.1, 0.2, 0.3, 0.4)):
    """
    Load the tabulated analytical Sod solutions.

    For every ``t`` in ``times`` the file ``{prefix}/sod_s_shock_{round(100 t)}.csv``
    is read (e.g. ``sod_s_shock_10.csv`` for ``t = 0.1``). The first row is a
    header and is skipped. The following rows hold ``x, rho, vx, p``. Blank lines
    are ignored.

    param: prefix: directory holding the CSV files
    param: times: times to load; they become the keys of the result

    return: ``{t: {'x': ..., 'rho': ..., 'vx': ..., 'p': ...}}`` with float numpy arrays
    """
    result = {}

    for t in times:
        path = f'{prefix}/sod_s_shock_{round(t * 100)}.csv'
        rows = []

        # newline='' is what the csv module requires for files it reads.
        with open(path, 'r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # header row
            for row in reader:
                if not row:  # tolerate blank lines, e.g. a trailing empty line
                    continue
                rows.append([float(row[0]), float(row[1]), float(row[2]), float(row[3])])

        table = np.array(rows, dtype=float).reshape(-1, 4)
        result[t] = {name: table[:, col].copy() for col, name in enumerate(('x', 'rho', 'vx', 'p'))}

    return result

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    SOLUTION = load_analytical_solution()

## Voxel Traversal  <a class="anchor" id="voxel-traversal"></a>

We use the fast voxel traversal algorithm of Amanatides & Woo (1987), "A Fast Voxel Traversal Algorithm for Ray Tracing", to sample the simulation grids along straight rays.

In [None]:
def ray_tracing(u, v, grid):
    """
    Enumerate the cells of ``grid`` crossed by a ray (Amanatides & Woo, 1987).

    param: u: origin vector (in grid unit), one component per grid axis
    param: v: direction vector (in grid unit); normalized internally
    param: grid: numpy array of any dimensionality >= 1 (only its shape is used)

    return: integer numpy array of shape (n_cells, grid.ndim) with the visited
            cells in traversal order; an empty array if ``u`` lies outside the grid.

    NOTE(review): every reported index is ``cell - 1``. The first boundary
    crossing is also offset by one cell relative to the textbook algorithm.
    This convention is kept because the lineout plots depend on it. As a
    consequence, a ray leaving through the lower face of an axis reports index
    -1 there, which numpy indexing wraps to the last cell.
    """
    n_axes = grid.ndim
    if n_axes < 1:
        return np.array([])

    v = np.array(v) / np.linalg.norm(v)
    boxsize = np.array(grid.shape)
    cell = np.floor(u)[:n_axes]

    step = np.zeros(n_axes)
    t_max = np.zeros(n_axes)
    t_delta = np.zeros(n_axes)

    for i in range(n_axes):
        if v[i] >= 0:
            step[i] = 1
            boundary = cell[i]
        else:
            step[i] = -1
            boundary = cell[i] - 1

        if v[i] != 0:
            t_max[i] = (boundary - u[i]) / v[i]
            t_delta[i] = 1 / abs(v[i])  # voxel size is 1 in grid units
        else:
            # The ray never crosses a boundary of this axis. An infinite t_max
            # keeps the axis from ever being chosen. The former finite sentinel
            # (the box size) could be reached on elongated grids, after which
            # the axis was stepped with t_delta = 0 until the ray left the box.
            t_max[i] = np.inf
            t_delta[i] = np.inf

    result = []
    while np.all((cell >= 0) & (cell < boxsize)):
        result.append([int(c) - 1 for c in cell])

        # Step along the axis whose boundary is reached first. On ties the
        # highest axis wins, which reproduces the original 2d/3d branch order.
        axis = n_axes - 1 - int(np.argmin(t_max[::-1]))
        cell[axis] += step[axis]
        t_max[axis] += t_delta[axis]

    return np.array(result)


def plot_ray_tracing(u, v, grid):
    """
    Show the cell indices visited by ``ray_tracing`` as a connected polyline.

    param: u: ray origin (in grid unit)
    param: v: ray direction (in grid unit)
    param: grid: 2d or 3d numpy array; for any other dimensionality an empty
                 (but sized) figure is shown
    """
    visited = ray_tracing(u, v, grid)
    trace_style = dict(
        mode='lines+markers',
        line=dict(color='red', width=.5),
        marker=dict(size=3),
    )

    fig = go.Figure()

    if grid.ndim == 3:
        fig.add_trace(go.Scatter3d(
            x=[cell[0] for cell in visited],
            y=[cell[1] for cell in visited],
            z=[cell[2] for cell in visited],
            **trace_style,
        ))
        axis_style = dict(showgrid=True, showline=True)
        fig.update_layout(scene=dict(xaxis=axis_style, yaxis=axis_style, zaxis=axis_style))
    elif grid.ndim == 2:
        fig.add_trace(go.Scatter(
            x=[cell[0] for cell in visited],
            y=[cell[1] for cell in visited],
            **trace_style,
        ))
        axis_style = dict(showgrid=True, showline=True, mirror=True)
        fig.update_layout(xaxis=axis_style, yaxis=axis_style)

    fig.update_layout(width=600, height=600, margin=dict(t=10, b=10, r=10, l=10))
    fig.show()
    

def convert_ray_tracing_to_x_y(u, v, grid, points):
    """
    Turn the cells visited by a ray into (distance, value) pairs.

    param: u: ray origin (in grid unit)
    param: v: ray direction (in grid unit); normalized internally
    param: grid: numpy array sampled at ``points``
    param: points: integer cell indices, one row per visited cell (as returned
                   by ``ray_tracing``)

    return: ``(x, y)`` numpy arrays. ``x`` is the projection of ``point - u``
            onto the unit direction; ``y`` is ``grid`` at that point.
    """
    v = np.array(v) / np.linalg.norm(v)

    x = [np.dot(p - u, v) for p in points]
    # tuple(p) selects one element for any grid dimensionality. The former
    # 2d/3d-only branches left ``y`` unbound for every other case.
    y = [grid[tuple(p)] for p in points]

    return np.array(x), np.array(y)


def voxel_traversal(u, v, grid):
    """
    Sample ``grid`` along the ray starting at ``u`` in direction ``v``.

    param: u: ray origin (in grid unit)
    param: v: ray direction (in grid unit)
    param: grid: numpy array to sample

    return: ``(x, y, points)``: distances along the ray, the sampled values,
            and the visited cell indices reported by ``ray_tracing``.
    """
    visited = ray_tracing(u, v, grid)
    distances, values = convert_ray_tracing_to_x_y(u, v, grid, visited)
    return distances, values, visited


def plot_points(u, v, grid):
    """
    Debug plot: heatmap of ``grid`` overlaid with the lineout sampled along a ray.

    For 3d grids the slice ``grid[:, :, u[2]]`` is shown. The sampled values
    are added to the ray's y coordinate, so the grey curve is the field profile
    drawn on top of the ray path.

    param: u: ray origin (in grid unit); ``u[2]`` selects the slice for 3d grids
    param: v: ray direction (in grid unit)
    param: grid: 2d or 3d numpy array
    """
    xx, yy, _ = voxel_traversal(u, v, grid)

    fig = go.Figure()

    # int() so that a float origin component can still select the slice.
    z = grid[:, :, int(u[2])] if grid.ndim == 3 else grid
    fig.add_trace(go.Heatmap(z=z, colorscale='Spectral'))

    # arctan2(|v1|, |v0|) equals arctan(|v1 / v0|) for v0 != 0. It is also
    # defined for rays parallel to the y axis, where the division raised
    # ZeroDivisionError.
    angle = np.arctan2(abs(v[1]), abs(v[0]))
    tx, ty = np.cos(angle), np.sin(angle)
    x = [u[0] + tx * s * v[0] for s in xx]
    y = [u[1] + ty * s * v[1] + value for s, value in zip(xx, yy)]

    fig.add_trace(go.Scatter(
        x=x, y=y, mode='lines',
        line=dict(color='grey', width=1.5),
    ))

    fig.update_layout(
        width=600, height=600,
        xaxis=dict(mirror=True,),
        yaxis=dict(mirror=True,),
    )

    fig.show()


def _make_grid(res):
    res = np.array(res)
    grid = np.zeros(res)
    
    lb = -(res / 2 / 10)
    ub = (res / 2 / 10)
    
    if len(res) == 2:
        x, y = np.mgrid[lb[0]:ub[0]:0.1, lb[1]:ub[1]:0.1]
        grid = np.sin(x**2 + y**2) # + np.sqrt(x**2 + y**2)
    elif len(res) == 3:
        x, y, z = np.mgrid[lb[0]:ub[0]:0.1, lb[1]:ub[1]:0.1, lb[2]:ub[2]:0.1]
        grid = np.sin(x**2 + y**2 + z**2) # + np.sqrt(x**2 + y**2 + z**2)
    else:
        print('[err] wrong dimensionality')
        return None
    
    return grid
    
if __name__ == '__main__' and '__file__' not in globals():
#     plot_ray_tracing(u=[32, 32, 32], v=[1, 1, 0.5], grid=_make_grid([64, 64, 64]))
#     plot_ray_tracing(u=[32, 32], v=[0, 1], grid=_make_grid([64, 64]))
#     plot_points(u=[32, 32, 32], v=[1, 1, 0], grid=_make_grid([64, 64, 64]))
#     plot_points(u=[32, 32], v=[1, 1], grid=_make_grid([64, 64]))
    pass

## Snapshots  <a class="anchor" id="snapshots"></a>

In [None]:
# Chombo HDF5 outputs of the Rhyme runs, keyed first by setup ('linear' or
# 'spherical') and then by dimensionality ('1d', '2d', '3d'). Each path names
# the *-000001 output of a run. NOTE(review): presumably PyRhyme finds the
# run's other snapshots from this file -- confirm.
SNAPSHOTS = {
    'linear': {
        '1d': {
            'path': 'sod_s_shock_1d/sod_s_shock_1d/sod_s_shock_1d-000001.chombo.h5',
        },
        '2d': {
            'path': 'sod_s_shock_2d/sod_s_shock_2d/sod_s_shock_2d-000001.chombo.h5',
        },
        '3d': {
            'path': 'sod_s_shock_3d/sod_s_shock_3d/sod_s_shock_3d-000001.chombo.h5',
        },
    },
    'spherical': {
        '2d': {
            'path': 'sod_s_shock_sphere_2d/sod_s_shock_sphere_2d/sod_s_shock_sphere_2d-000001.chombo.h5',
        },
        '3d': {
            'path': 'sod_s_shock_sphere_3d/sod_s_shock_sphere_3d/sod_s_shock_sphere_3d-000001.chombo.h5',
        },
    },
}

# Check if snapshots are accessible. This only reports; missing files are not
# an error here (`kind` is the per-setup dict of resolutions).
for kind in SNAPSHOTS.values():
    for res, snap in kind.items():
        if not Path(snap['path']).is_file():
            print(f"[err] Cannot access {res} snap @ {snap['path']}")
        else:
            print(f"[accessible] {res} snap @ {snap['path']}")

## Loading Snapshots  <a class="anchor" id="loading-snapshots"></a>

In [None]:
def load_snapshot(times):
    """
    Load every run listed in ``SNAPSHOTS`` at the requested times.

    param: times: iterable of simulation times; for each one the closest stored
                  snapshot is chosen by ``time_to_snap_id``

    return: nested dict ``result[kind][res][t]`` holding the fields returned by
            ``load_rhyme`` (``kind`` and ``res`` are the keys of ``SNAPSHOTS``)
    """
    result = {}
    for kind, snaps in SNAPSHOTS.items():
        result[kind] = {}

        for res, snapinfo in snaps.items():
            result[kind][res] = {}
            r = PyRhyme(snapinfo['path'])

            try:
                n_dim = int(res[:1])  # '2d' -> 2
                for t in times:
                    result[kind][res][t] = load_rhyme(time_to_snap_id(t, r), r, n_dim)
            finally:
                # Release every reader, even if loading fails. The cleanup used
                # to sit outside this loop and only closed the last reader of
                # each kind.
                r.dataset.close_current()
                r.dataset.clean_all()

    return result

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    DATA = load_snapshot([0.1, 0.2, 0.3, 0.4])

## Linear Shock Tube Plots  <a class="anchor" id="linear-plots"></a>

In [None]:
def plot_linear_shock_tube_maps():
    """
    Show one heatmap per (time, attribute, resolution) of the linear runs.

    Reads the module-level ``DATA``. 1d runs are skipped. 3d runs are shown
    through the slice at index 64 of their first axis.
    """
    times = DATA['linear']['2d'].keys()

    for t in times:
        for attr in ['rho', 'vx', 'vy', 'v2', 'M', 'e_tot']:
            for res, snap in DATA['linear'].items():
                if res == '1d':
                    continue

                # Create the figure only for maps that are actually drawn.
                fig = go.Figure()
                z = snap[t][attr] if res == '2d' else snap[t][attr][64, :, :]
                fig.add_trace(go.Heatmap(z=z, colorscale='Spectral',))

                fig.update_layout(
                    title=f"{res} {attr} @ {t}",
                    width=600, height=600,
                    xaxis=dict(mirror=True,),
                    yaxis=dict(mirror=True,),
                )

                fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    plot_linear_shock_tube_maps()

In [None]:
def plot_linear_shock_tube_lineouts(
    times=[0.1, 0.2, 0.3, 0.4],
    attrs=['rho', 'vx', 'p'],
    attr_labels=['ρ', 'v<sub>x</sub>', 'p']
):
    """
    Compare lineouts of the linear shock-tube runs with the analytical solution.

    Builds a grid of subplots (row = attribute, column = time). Each subplot
    shows the analytical profile from the module-level ``SOLUTION`` (grey) and
    one curve per resolution in ``DATA['linear']``. The figure is written to
    ``FIGURE_DIR`` as SVG and PNG, then shown.

    param: times: snapshot times; must be keys of ``SOLUTION`` and of every
                  ``DATA['linear'][res]``
    param: attrs: field names; must exist in both ``SOLUTION[t]`` and ``DATA``
    param: attr_labels: y-axis titles, one per entry of ``attrs``
    """
    # One color per resolution; linspace yields one extra sample that is unused.
    colors = px.colors.sample_colorscale(px.colors.diverging.Portland_r, np.linspace(0, 1, len(DATA['linear'].keys())+1))
    
    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1]*len(attrs), column_widths=[1]*len(times),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x / L<sub>box</sub>',
        subplot_titles=[f"t = {t}" for t in times],
    )
    
    # The first len(times) annotations are the subplot titles: enlarge and lift them.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # Next annotation: presumably the shared x_title added by make_subplots.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125
    
    for irow, attr in enumerate(attrs):
        for icol, t in enumerate(times):
            # Analytical solution, last sample dropped and x rescaled by 256.
            # NOTE(review): assumes x is in box units and the tube spans 256 cells -- confirm.
            fig.add_trace(go.Scatter(
                x=SOLUTION[t]['x'][:-1] * 256, y=SOLUTION[t][attr][:-1], name='Sol.',
                mode='lines', line=dict(color='grey', width=2), opacity=0.75,
#                 mode='markers', marker=dict(color='grey', size=4),
                showlegend=True if irow == icol == 0 else False
            ), irow+1, icol+1)
            
            for ci, (res, snap) in enumerate(DATA['linear'].items()):
                if res == '1d':
                    # 1d profile is used directly, without its first cell
                    # (NOTE(review): presumably a boundary cell -- confirm).
                    x = list(range(len(snap[t][attr][1:])))
                    y = snap[t][attr][1:]
                elif res == '2d':
                    # Ray along the second axis through column 127.5.
                    u, v = [127.5, 1], [0, 1]
                    x, y, _ = voxel_traversal(u, v, snap[t][attr])
                elif res == '3d':
                    # Ray along the third axis through (127.5, 127.5).
                    u, v = [127.5, 127.5, 1], [0, 0, 1]
                    x, y, _ = voxel_traversal(u, v, snap[t][attr])
                
                fig.add_trace(go.Scatter(
                    x=np.array(x), y=y, name=f"{res}",
                    mode='lines', line=dict(color=colors[ci], width=1),
                    showlegend=True if irow == icol == 0 else False
                ), irow+1, icol+1)
                
            # x ticks are in cell units, labelled as fractions of the box.
            fig.update_xaxes(
                row=irow+1, col=icol+1, mirror=True,
                tickmode='array', tickvals=[1, 128, 254], ticktext=['0', '.5', '1'],
                titlefont=dict(size=18), tickfont=dict(size=14),
            )
            
            tickvals = np.linspace(0, 2, 11)
            fig.update_yaxes(
                title=attr_labels[irow] if icol == 0 else None,
                row=irow+1, col=icol+1, mirror=True,
                tickmode='array', tickvals=tickvals, showticklabels=True if icol == 0 else False,
                titlefont=dict(size=18), tickfont=dict(size=14),
            )
        
    width, height = 1000, 820
    
    fig.update_layout(
        width=width, height=height,
        margin=dict(t=10, b=70, l=60, r=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/sod_s_shock_tube.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sod_s_shock_tube.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    plot_linear_shock_tube_lineouts(times=[0.1, 0.2, 0.3, 0.4], attrs=['rho', 'vx', 'p'], attr_labels=['ρ', 'v<sub>x</sub>', 'p'])

## Spherical Shock Tube Plots  <a class="anchor" id="spherical-plots"></a>

In [None]:
def plot_spherical_shock_tube_maps(
    attr='rho', attr_label='ρ', times=[0.1, 0.2, 0.3, 0.4], resolutions=['2d', '3d']):
    """
    Plot maps of ``attr`` for the spherical runs (row = resolution, column = time).

    2d runs are drawn as stored. 3d runs are drawn through the slice at index
    128 of their first axis. All panels share the fixed color range
    [zmin, zmax] and a single colorbar. The figure is written to ``FIGURE_DIR``
    as SVG and PNG, then shown.

    param: attr: field name in ``DATA['spherical'][res][t]``
    param: attr_label: colorbar title
    param: times: snapshot times (one column each)
    param: resolutions: keys of ``DATA['spherical']`` (one row each)
    """
    # Fixed color range shared by all panels, and number of colorbar ticks.
    zmin, zmax, nticks = .125, .375, 3
    fig = make_subplots(
        len(resolutions), len(times),
        row_heights=[1]*len(resolutions), column_widths=[1]*len(times),
        vertical_spacing=0.035, horizontal_spacing=0.0175,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x / L<sub>box</sub>', y_title='y / L<sub>box</sub>',
        subplot_titles=[f"t = {t}" for t in times],
    )
    
    # The first len(times) annotations are the subplot titles: enlarge and lift them.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # The next two are presumably the shared x_title and y_title added by make_subplots.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125
    fig.layout.annotations[len(times)+1]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)+1]["x"] -= 0.0125
    
    for irow, res in enumerate(resolutions):
        for icol, t in enumerate(times):

            z = DATA['spherical'][res][t][attr] if res == '2d' else DATA['spherical'][res][t][attr][128, :, :]
            
            # Only the first panel draws the (shared) colorbar.
            fig.add_trace(go.Heatmap(
                z=z, colorscale='Sunset', reversescale=True,
                zauto=False, zmin=zmin, zmax=zmax,
                colorbar=dict(
                    title=attr_label, titleside='bottom', titlefont=dict(size=18),
                    thicknessmode='pixels', thickness=10, tickfont=dict(size=18),
                    tickvals=list(np.linspace(zmin, zmax, nticks)),
                    x=1.02, y=0.5, len=0.8, xanchor='left', yanchor='middle',
                ),
                showscale=True if irow == icol == 0 else False
            ), irow+1, icol+1)
            
            # Axis ids: the first subplot uses 'x'/'y', the k-th one 'x{k}'/'y{k}'.
            # The label shows the resolution near the lower-right corner (data coordinates).
            axisref = '' if irow == icol == 0 else f'{irow * len(times) + icol + 1}'
            fig.add_annotation(
                x=250, y=7, xref=f'x{axisref}', yref=f'y{axisref}', xanchor='right', yanchor='bottom',
                text=res, showarrow=False, font=dict(color="white", size=18),
            )
            
            # Ticks are in cell units, labelled as fractions of the box.
            fig.update_xaxes(
                row=irow+1, col=icol+1, mirror=True,
                showticklabels=True if irow == len(resolutions)-1 else False,
                tickmode='array', tickvals=[1, 128, 255], ticktext=['0', '.5', '1'],
                titlefont=dict(size=18), tickfont=dict(size=14),
            )
            
            fig.update_yaxes(
                row=irow+1, col=icol+1, mirror=True,
                showticklabels=True if icol == 0 else False,
                tickmode='array', tickvals=[1, 128, 255], ticktext=['0', '.5', '1'],
                titlefont=dict(size=18), tickfont=dict(size=14),
            )
            
    width, height = 1000, 535
                
    fig.update_layout(
        width=width, height=height,
        margin=dict(t=60, b=70, l=70, r=10),
    )
    
    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_maps.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_maps.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    plot_spherical_shock_tube_maps()

In [None]:
def plot_spherical_shock_tube_lineouts_2d(
    times=(0.1, 0.2, 0.3, 0.4),
    attrs=('rho', '|v|', 'p'),
    attr_labels=('ρ', '|v|', 'p')
):
    """
    Plot lineouts of the 2d spherical run along several directions.

    Each subplot (row = attribute, column = time) overlays the profiles sampled
    from the point (127.5, 127.5) along every angle θ listed below. The figure
    is written to ``FIGURE_DIR`` as SVG and PNG, then shown.

    param: times: snapshot times (keys of ``DATA['spherical']['2d']``)
    param: attrs: field names; each must be a key of the tick tables below
    param: attr_labels: y-axis titles, one per entry of ``attrs``. The default
                        label of '|v|' was previously 'v<sub>x</sub>'.
    """
    u = [127.5, 127.5]
    θs = [(0, '0'), (pi/3, 'π/3'), (2*pi/3, '2π/3'), (pi, 'π'), (3*pi/2, '3π/2'), (7*pi/4, '7π/4')]
    colors = px.colors.sample_colorscale(px.colors.diverging.Portland_r, np.linspace(0, 1, len(θs)))

    tickvals = {
        'rho': np.linspace(0, 0.5, 6),
        '|v|': np.linspace(0, 2, 5),
        'p': np.linspace(0, 0.4, 5),
        'e_tot': np.linspace(0, 1, 11),
    }

    tickrange = {
        'rho': [0, .4],
        '|v|': [0, 1.2],
        'p': [0, 0.4],
        'e_tot': [0, 2],
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1]*len(attrs), column_widths=[1]*len(times),
        vertical_spacing=0.035, horizontal_spacing=0.0175,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x / L<sub>box</sub>',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the subplot titles: enlarge and lift them.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # Next annotation: presumably the shared x_title added by make_subplots.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, (attr, attr_label) in enumerate(zip(attrs, attr_labels)):
        for icol, t in enumerate(times):
            snap = DATA['spherical']['2d'][t]

            for i, (θ, θlabel) in enumerate(θs):
                v = [np.cos(θ), np.sin(θ)]
                x, y, _ = voxel_traversal(u, v, snap[attr])

                # Legend entries (one per θ) are taken from subplot (2, 2).
                fig.add_trace(go.Scatter(
                    x=x, y=y, name=f"θ = {θlabel}",
                    mode='lines', line=dict(color=colors[i], width=0.75),
                    showlegend=True if irow == icol == 1 else False,
                ), irow+1, icol+1)

            # Axis styling depends only on the subplot: apply it once per
            # subplot instead of once per ray.
            fig.update_xaxes(
                row=irow+1, col=icol+1, title=None, mirror=True,
                tickmode='array', tickvals=[0, 64, 128, 192], ticktext=['0', '.25', '.5', '.75'],
                titlefont=dict(size=18), tickfont=dict(size=14),
                showticklabels=True if irow == len(attrs)-1 else False
            )

            fig.update_yaxes(
                row=irow+1, col=icol+1, mirror=True, range=tickrange[attr],
                titlefont=dict(size=18), title=attr_label if icol == 0 else None,
                tickmode='array', tickfont=dict(size=14), tickvals=tickvals[attr],
                showticklabels=True if icol == 0 else False,
            )

    width, height = 1000, 800

    fig.update_layout(
        width=width, height=height,
    )

    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_2d_lineouts.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_2d_lineouts.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    plot_spherical_shock_tube_lineouts_2d()

In [None]:
def plot_spherical_shock_tube_lineouts_3d(
    times=(0.1, 0.2, 0.3, 0.4),
    attrs=('rho', '|v|', 'p'),
    attr_labels=('ρ', '|v|', 'p')
):
    """
    Plot lineouts of the 3d spherical run along many directions.

    Each subplot (row = attribute, column = time) overlays the profiles sampled
    from (127.5, 127.5, 127.5) along every combination of the azimuths φ and
    polar angles θ listed below. The figure is written to ``FIGURE_DIR`` as SVG
    and PNG, then shown.

    param: times: snapshot times (keys of ``DATA['spherical']['3d']``)
    param: attrs: field names; each must be a key of the tick tables below
    param: attr_labels: y-axis titles, one per entry of ``attrs``
    """
    u = [127.5, 127.5, 127.5]
    φs = [(0, '0'), (pi/3, 'π/3'), (2*pi/3, '2π/3'), (pi, 'π'), (3*pi/2, '3π/2'), (7*pi/4, '7π/4')]
    θs = [(0, '0'), (pi/4, 'π/4'), (pi/2, 'π/2'), (2*pi/3, '2π/3'), (pi, 'π')]
    colors = px.colors.sample_colorscale(px.colors.diverging.Portland_r, np.linspace(0, 1, len(φs) * len(θs)))

    tickvals = {
        'rho': np.linspace(0, 0.5, 6),
        '|v|': np.linspace(0, 2, 5),
        'p': np.linspace(0, 0.4, 5),
        'e_tot': np.linspace(0, 1, 11),
    }

    tickrange = {
        'rho': [0, .3],
        '|v|': [0, 1.3],
        'p': [0, 0.2],
        'e_tot': [0, 2.1],
    }

    fig = make_subplots(
        len(attrs), len(times),
        row_heights=[1]*len(attrs), column_widths=[1]*len(times),
        vertical_spacing=0.035, horizontal_spacing=0.0175,
        shared_xaxes=True, shared_yaxes=False,
        x_title='x / L<sub>box</sub>',
        subplot_titles=[f"t = {t}" for t in times],
    )

    # The first len(times) annotations are the subplot titles: enlarge and lift them.
    for i in range(len(times)):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    # Next annotation: presumably the shared x_title added by make_subplots.
    fig.layout.annotations[len(times)]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len(times)]["y"] -= 0.0125

    for irow, (attr, attr_label) in enumerate(zip(attrs, attr_labels)):
        for icol, t in enumerate(times):
            snap = DATA['spherical']['3d'][t]

            for ii, (φ, φ_str) in enumerate(φs):
                for ij, (θ, θ_str) in enumerate(θs):
                    # Direction from polar angle θ and azimuth φ. θ = 0 and
                    # θ = π give (numerically) the same ray for every φ.
                    v = [np.sin(θ) * np.cos(φ), np.sin(θ) * np.sin(φ), np.cos(θ)]
                    x, y, _ = voxel_traversal(u, v, snap[attr])

                    # Legend entries (one per direction) come from subplot (2, 2).
                    fig.add_trace(go.Scatter(
                        x=x, y=y, name=f"φ = {φ_str}, θ = {θ_str}",
                        mode='lines', line=dict(color=colors[ii * len(θs) + ij], width=0.75),
                        showlegend=True if irow == icol == 1 else False,
                    ), irow+1, icol+1)

            # Axis styling depends only on the subplot: apply it once per
            # subplot instead of once per ray (30 times before).
            fig.update_xaxes(
                row=irow+1, col=icol+1, title=None, mirror=True,
                tickmode='array', tickvals=[0, 64, 128, 192], ticktext=['0', '.25', '.5', '.75'],
                titlefont=dict(size=18), tickfont=dict(size=14),
                showticklabels=True if irow == len(attrs)-1 else False
            )

            fig.update_yaxes(
                row=irow+1, col=icol+1, mirror=True, range=tickrange[attr],
                titlefont=dict(size=18), title=attr_label if icol == 0 else None,
                tickmode='array', tickfont=dict(size=14), tickvals=tickvals[attr],
                showticklabels=True if icol == 0 else False,
            )

    width, height = 1000, 800

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=70, b=70, l=60, r=10),
    )

    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_3d_lineouts.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/sod_s_sphere_shock_3d_lineouts.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    # Only when executed as __main__ without a __file__ (e.g. in a notebook kernel).
    plot_spherical_shock_tube_lineouts_3d()