# Domain Decomposition & Blocking Visualization

In [None]:
import plotly.graph_objects as go
import numpy as np
from ipywidgets import IntSlider, Layout, HBox, VBox, HTML, interactive_output
from IPython.display import display

In [None]:
def get_cube_wireframe(x0, y0, z0, x1, y1, z1):
    """Return ``(x, y, z)`` line coordinates tracing the 12 edges of a box.

    The box is axis-aligned with opposite corners ``(x0, y0, z0)`` and
    ``(x1, y1, z1)``. The three returned lists are parallel and meant for a
    ``Scatter3d`` trace in ``mode='lines'``. ``None`` entries split the path
    into six separate strokes: the closed loop of the z0 face, the closed loop
    of the z1 face, and the four vertical edges joining them.
    """
    # Footprint corners in the XY plane, walked in order.
    corners = [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
    closed_loop = corners + corners[:1]

    strokes = [
        [(cx, cy, z0) for cx, cy in closed_loop],  # bottom face outline
        [(cx, cy, z1) for cx, cy in closed_loop],  # top face outline
    ]
    # One vertical edge per footprint corner.
    strokes += [[(cx, cy, z0), (cx, cy, z1)] for cx, cy in corners]

    points = []
    for idx, stroke in enumerate(strokes):
        if idx:
            # Break the polyline between consecutive strokes.
            points.append((None, None, None))
        points.extend(stroke)

    xs, ys, zs = (list(axis) for axis in zip(*points))
    return xs, ys, zs

def get_cube_mesh(x0, y0, z0, x1, y1, z1):
    """Return vertices and triangle indices for a closed axis-aligned box.

    The box spans corners ``(x0, y0, z0)`` to ``(x1, y1, z1)``. Returns
    ``(x, y, z, i, j, k)``: the eight vertex coordinate lists followed by
    three index lists describing 12 triangles (two per face), in the form
    expected by ``plotly.graph_objects.Mesh3d``.
    """
    # Vertices 0-3 walk the z0 face starting at (x0, y0); vertices 4-7 repeat
    # the same walk on the z1 face, so vertex v + 4 sits directly above v.
    footprint = ((x0, y0), (x1, y0), (x1, y1), (x0, y1))
    vertices = [(cx, cy, cz) for cz in (z0, z1) for cx, cy in footprint]
    x, y, z = (list(coord) for coord in zip(*vertices))

    triangles = (
        (7, 3, 0), (0, 4, 7),  # x = x0 face
        (0, 1, 2), (0, 2, 3),  # z = z0 face
        (4, 5, 6), (4, 6, 7),  # z = z1 face
        (6, 5, 1), (6, 2, 1),  # x = x1 face
        (4, 0, 5), (0, 1, 5),  # y = y0 face
        (3, 6, 7), (2, 3, 6),  # y = y1 face
    )
    i, j, k = (list(column) for column in zip(*triangles))
    return x, y, z, i, j, k

In [None]:
# Widget styling shared by the slider cells: "initial" lets each slider's
# description label take its natural width rather than a fixed one.
style = dict(description_width="initial")
# NOTE(review): not referenced by the slider cells in this section, which
# build and use their own `slider_layout`; kept for any other users.
layout_slider = Layout(width="95%")

## Single-Device Blocking Visualization

Visualization of the domain decomposition strategy for the `single_dev_main` application.

This visualization demonstrates the **Temporal Blocking (Halo Blocking)** strategy on a single device.

- **$I$ (Inner Loop)**: The number of time steps executed per block. This determines the required halo thickness, $\text{Halo} = I - 1$.
- **$\mathbf{B}$ (Block Shape)**: The spatial dimensions of the *total* region (Valid + Halo) read by one CUDA block.
- **$\mathbf{S}$ (Block Dim Valid)**: The spatial dimensions of the *valid* region updated by one CUDA block.
- **$\mathbf{n}$ (Block Count)**: The number of blocks covering the domain.
- **$\mathbf{N}$ (Domain Shape)**: The total simulation domain size ($\mathbf{N} = \mathbf{n} \times \mathbf{S}$, element-wise).

**Legend**:
- **Red Box**: Total Region ($\mathbf{B}$); the halo is added only on faces that border a neighboring block. Along each axis:
  - Interior blocks: $B = S + 2 \times \text{Halo}$
  - Boundary blocks: $B = S + \text{Halo}$ (or $B = S$ if $n = 1$ along that axis)
- **Blue Dashed Box**: Valid Region ($\mathbf{S}$, updated by the block).


In [None]:
def plot_single_dev_blocking(I, Sx, Sy, Sz, nx, ny, nz):
    """Draw the single-device temporal-blocking layout and print its efficiency.

    The domain is tiled by ``nx * ny * nz`` blocks. Each block owns a valid
    region of ``(Sx, Sy, Sz)`` cells (blue, dashed). It reads a total region
    (red) that grows the valid region by ``halo = I - 1`` cells on every face
    bordering a neighbouring block. Faces on the domain boundary get no halo.

    Args:
        I: Inner-loop time steps per block; must be >= 1.
        Sx, Sy, Sz: Valid-region extent of one block along each axis.
        nx, ny, nz: Number of blocks along each axis.

    Prints the total valid domain ``N = n * S`` and the ratio of valid cells
    to total (valid + halo) cells summed over all blocks. It then shows a
    plotly 3D figure of the decomposition.

    Raises:
        ValueError: If ``I < 1``, which would give a negative halo.
    """
    if I < 1:
        raise ValueError(f"I must be >= 1, got {I}")
    halo = I - 1

    S = np.array([Sx, Sy, Sz])  # valid extent of one block
    n = np.array([nx, ny, nz])  # block count per axis
    N = S * n  # total valid domain

    # Efficiency = (sum of valid volumes) / (sum of red-box volumes).
    # A block's extent along an axis depends only on its index along that
    # axis, so both sums factor into per-axis sums. Along one axis, each of
    # the n - 1 internal block boundaries adds `halo` cells to the blocks on
    # both sides, so the summed total extent is N + 2 * halo * (n - 1).
    # This is exactly 1 when n == 1 or halo == 0.
    total_extent = N + 2 * halo * (n - 1)
    phi = float(np.prod(N / total_extent))

    print(f"N (Total Valid Domain): {N}")
    print()

    print(f"Theoretical Efficiency (Valid/Total Work): {phi:.2%}")
    print()

    # Gather every block's edges into one trace per colour rather than one
    # trace per box. The None separators keep the boxes disconnected.
    valid_x, valid_y, valid_z = [], [], []
    total_x, total_y, total_z = [], [], []
    meshes = []

    for z in range(nz):
        for y in range(ny):
            for x in range(nx):
                # Valid region of this block.
                vx0, vy0, vz0 = x * Sx, y * Sy, z * Sz
                vx1, vy1, vz1 = vx0 + Sx, vy0 + Sy, vz0 + Sz

                # Total region: extend by the halo only toward existing neighbours.
                hx0 = vx0 - halo if x > 0 else vx0
                hy0 = vy0 - halo if y > 0 else vy0
                hz0 = vz0 - halo if z > 0 else vz0
                hx1 = vx1 + halo if x < nx - 1 else vx1
                hy1 = vy1 + halo if y < ny - 1 else vy1
                hz1 = vz1 + halo if z < nz - 1 else vz1

                wx, wy, wz = get_cube_wireframe(vx0, vy0, vz0, vx1, vy1, vz1)
                valid_x += wx + [None]
                valid_y += wy + [None]
                valid_z += wz + [None]

                wx, wy, wz = get_cube_wireframe(hx0, hy0, hz0, hx1, hy1, hz1)
                total_x += wx + [None]
                total_y += wy + [None]
                total_z += wz + [None]

                # Translucent solid over the total region; it carries the hover text.
                mx, my, mz, i, j, k = get_cube_mesh(hx0, hy0, hz0, hx1, hy1, hz1)
                meshes.append(go.Mesh3d(
                    x=mx, y=my, z=mz, i=i, j=j, k=k,
                    opacity=0.1, color='red', showlegend=False, hoverinfo='text',
                    text=f'Block ({x},{y},{z})<br>Valid: [{vx0}:{vx1}, {vy0}:{vy1}, {vz0}:{vz1}]<br>Total: [{hx0}:{hx1}, {hy0}:{hy1}, {hz0}:{hz1}]'
                ))

    fig = go.Figure()
    # These two traces double as the legend entries (no dummy traces needed).
    fig.add_trace(go.Scatter3d(
        x=valid_x, y=valid_y, z=valid_z, mode='lines', name='Valid Region (S)',
        line=dict(color='blue', width=3, dash='dash'), hoverinfo='skip'
    ))
    fig.add_trace(go.Scatter3d(
        x=total_x, y=total_y, z=total_z, mode='lines', name='Total Region (B)',
        line=dict(color='red', width=2), hoverinfo='skip'
    ))
    fig.add_traces(meshes)

    fig.update_layout(
        scene=dict(xaxis=dict(title='X'), yaxis=dict(title='Y'), zaxis=dict(title='Z'), aspectmode='data'),
        title=f"Single-Device Blocking (I={I}, Halo={halo})",
        margin=dict(l=0, r=0, b=0, t=40), height=600
    )
    fig.show()

In [None]:
# One Layout object, shared by every slider in this cell.
slider_layout = Layout(width='95%')

# --- Sliders ---------------------------------------------------------------
# Inner-loop steps I (plot_single_dev_blocking derives halo = I - 1).
w_I = IntSlider(description="I (Inner)", style=style, layout=slider_layout,
                min=1, max=10, step=1, value=1)

# Valid region S of one block, per axis.
w_Sx = IntSlider(description="Sx", style=style, layout=slider_layout,
                 min=4, max=100, step=4, value=40)
w_Sy = IntSlider(description="Sy", style=style, layout=slider_layout,
                 min=4, max=100, step=4, value=40)
w_Sz = IntSlider(description="Sz", style=style, layout=slider_layout,
                 min=4, max=100, step=4, value=40)

# Number of blocks n, per axis.
w_nx = IntSlider(description="nx", style=style, layout=slider_layout,
                 min=1, max=5, step=1, value=2)
w_ny = IntSlider(description="ny", style=style, layout=slider_layout,
                 min=1, max=5, step=1, value=2)
w_nz = IntSlider(description="nz", style=style, layout=slider_layout,
                 min=1, max=5, step=1, value=1)

# --- Arrangement: I on top, the S and n columns side by side below ----------
ui = VBox(children=[
    VBox(children=[HTML("<b>Temporal Blocking Parameters:</b>"), w_I]),
    HBox(children=[
        VBox(children=[HTML("<b>Valid Region S:</b>"), w_Sx, w_Sy, w_Sz],
             layout=Layout(width='50%')),
        VBox(children=[HTML("<b>Block Count n:</b>"), w_nx, w_ny, w_nz],
             layout=Layout(width='50%')),
    ]),
])

# Each slider feeds the plot function's keyword argument of the same name.
out = interactive_output(
    plot_single_dev_blocking,
    dict(I=w_I, Sx=w_Sx, Sy=w_Sy, Sz=w_Sz, nx=w_nx, ny=w_ny, nz=w_nz),
)

display(ui, out)

## Multi-Device Blocking Visualization

Visualization of the domain decomposition strategy for the `multi_dev_main` application.

Adjust the sliders to see how the global domain is split across multiple GPUs (Devices).

The visualization shows the bounding box of each device's sub-domain in the global lattice coordinate system.
- **Dev Dim** ($n_x, n_y, n_z$): Number of devices along the X, Y, and Z directions.
- **Per-Device Domain** ($B_x, B_y, B_z$): The local domain size on each GPU, which corresponds to `CUDA Block Dim * CUDA Grid Dim`.
- **Global Domain**: `Per-Device Domain * Dev Dim` (element-wise).

In [None]:
def plot_multi_dev_blocking(nx, ny, nz, Bx, By, Bz):
    """Draw how the global lattice is split into per-device sub-domains.

    Devices form an ``nx x ny x nz`` grid, and each owns a ``(Bx, By, Bz)``
    box of lattice cells. The global domain is therefore
    ``(nx*Bx, ny*By, nz*Bz)``. Device ``(x, y, z)`` is labelled with the
    linear id ``x + y * nx + z * nx * ny``, so x varies fastest.

    Args:
        nx, ny, nz: Number of devices along each axis.
        Bx, By, Bz: Per-device domain extent along each axis.

    Prints the global domain size. It then shows a plotly 3D figure with,
    per device, a wireframe, a translucent hoverable box, and a centre label.
    """
    n = np.array([nx, ny, nz])
    B = np.array([Bx, By, Bz])
    N = B * n  # global domain size

    print(f"Global Domain: {N}")
    print()

    wireframes = []
    faces = []
    # All GPU labels share one text trace rather than one trace per device.
    # This keeps the trace count down for large device grids.
    label_x, label_y, label_z, label_text = [], [], [], []

    for x in range(nx):
        for y in range(ny):
            for z in range(nz):
                gpu_id = x + y * nx + z * nx * ny

                # Corners of this device's sub-domain.
                x0, y0, z0 = x * Bx, y * By, z * Bz
                x1, y1, z1 = x0 + Bx, y0 + By, z0 + Bz

                # Edges
                wx, wy, wz = get_cube_wireframe(x0, y0, z0, x1, y1, z1)
                wireframes.append(go.Scatter3d(
                    x=wx, y=wy, z=wz, mode="lines", name=f"GPU {gpu_id}",
                    line=dict(width=4), showlegend=False,
                ))

                # Faces (semi-transparent; carries the hover text)
                mx, my, mz, i, j, k = get_cube_mesh(x0, y0, z0, x1, y1, z1)
                faces.append(go.Mesh3d(
                    x=mx, y=my, z=mz, i=i, j=j, k=k, opacity=0.1, color="blue",
                    showscale=False, hoverinfo="text",
                    text=f"Device ({x}, {y}, {z})<br>Range: [{x0}:{x1}, {y0}:{y1}, {z0}:{z1}]",
                ))

                # Centre label
                label_x.append((x0 + x1) / 2)
                label_y.append((y0 + y1) / 2)
                label_z.append((z0 + z1) / 2)
                label_text.append(f"GPU {gpu_id}")

    labels = go.Scatter3d(
        x=label_x, y=label_y, z=label_z, mode="text", text=label_text,
        showlegend=False, textfont=dict(color="black", size=10),
    )

    # Wireframes come first so they take consecutive slots in the default
    # colour cycle, which gives neighbouring devices distinct edge colours.
    fig = go.Figure(data=wireframes + faces + [labels])

    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X (Lattice Units)"),
            yaxis=dict(title="Y (Lattice Units)"),
            zaxis=dict(title="Z (Lattice Units)"),
            aspectmode="data",
        ),
        title=f"Multi-Device Decomposition (Total: {nx * ny * nz} GPUs)",
        margin=dict(l=0, r=0, b=0, t=40),
        height=800,
    )
    fig.show()

In [None]:
# NOTE: this rebinds w_nx / w_ny / w_nz from the single-device cell. That
# cell's interactive_output already holds its own widget references, so it
# keeps working.

# Device grid: number of GPUs along each axis.
w_nx = IntSlider(description='nx', style=style, layout=slider_layout,
                 min=1, max=8, step=1, value=2)
w_ny = IntSlider(description='ny', style=style, layout=slider_layout,
                 min=1, max=8, step=1, value=2)
w_nz = IntSlider(description='nz', style=style, layout=slider_layout,
                 min=1, max=8, step=1, value=1)

# Per-device domain extent, in steps of 32 lattice cells.
w_Bx = IntSlider(description='Bx', style=style, layout=slider_layout,
                 min=32, max=1024, step=32, value=256)
w_By = IntSlider(description='By', style=style, layout=slider_layout,
                 min=32, max=1024, step=32, value=256)
w_Bz = IntSlider(description='Bz', style=style, layout=slider_layout,
                 min=32, max=1024, step=32, value=256)

# Two columns: device counts on the left, per-device sizes on the right.
ui = HBox(children=[
    VBox(children=[HTML("<b>Device Dimensions (Grid of GPUs):</b>"), w_nx, w_ny, w_nz],
         layout=Layout(width='50%')),
    VBox(children=[HTML("<b>Per-Device Domain Size:</b>"), w_Bx, w_By, w_Bz],
         layout=Layout(width='50%')),
])

# Each slider feeds the plot function's keyword argument of the same name.
out = interactive_output(
    plot_multi_dev_blocking,
    dict(nx=w_nx, ny=w_ny, nz=w_nz, Bx=w_Bx, By=w_By, Bz=w_Bz),
)

display(ui, out)