In [None]:
import numpy as np
import math
import itertools as it
import colorsys

# --- Bokeh Imports ---
from bokeh.plotting import figure, show
from bokeh.models import (
    ColumnDataSource, 
    HoverTool, 
    CategoricalColorMapper, 
    LinearColorMapper,
    ColorBar,
    CustomJS,
    RadioGroup,
    Range1d,
    Div,
    Plot,
    Rect,
    Text
)
from bokeh.layouts import column, row
from bokeh.io import output_notebook, show


class Layout:
    """
    Analyzes shared memory access patterns for various tensor layout permutations.

    This class models the process of mapping a high-level, N-dimensional
    tensor access (defined by shape `S` and strides `D`) to the underlying
    physical shared memory (smem) addresses.

    It computes all permutations of the tensor strides, calculates the
    resulting smem word addresses for every element, and analyzes the
    resulting bank conflicts for each layout, both with and without
    a swizzle operation.

    Each flat element index f (0 <= f < prod(S)) is decomposed into an N-D
    index in row-major (C) order, i.e. the last dimension varies fastest.
    Consecutive groups of `warp_size` flat indices form one warp.

    Args:
        m (int): The dimensionality of the tensor (e.g., 2 for a 2D tensor).
        S (np.ndarray): The shape of the tensor (1D array of size `m`).
        D (np.ndarray): The base strides of the tensor (1D array of size `m`).
        b_bits (int): The number of bits to use in the swizzle operation.
        m_base (int): The base bit offset for the swizzle operation.
        s_shift (int): The bit-shift distance for the swizzle operation
            (sign selects the direction; |s_shift| must be >= b_bits).
        warp_size (int): The number of threads in a warp (e.g., 32). Must be positive.
        smem_word_width (int): The width of a shared memory word in bits (e.g., 32 or 64).
            Must be a positive multiple of `address_atom_width`.
        element_vector_size (int): The number of scalar elements per vector (e.g., 1, 2, 4).
        element_type_width (int): The width of a single scalar element in bits (e.g., 16 for fp16).
        address_atom_width (int): The smallest addressable unit in bits (e.g., 4 for 4-bit atoms).
        n_banks (int): The number of shared memory banks (e.g., 32). Must be positive.

    Raises:
        AssertionError: If any parameter fails validation.

    Attributes:
        m (int): Tensor dimensionality.
        S (np.ndarray): Tensor shape.
        D (np.ndarray): Tensor strides.
        N_perms (int): Number of permutations (m!).
        perms (np.ndarray): Array of shape (N_perms, m) holding all permutations.
        D_permuted (np.ndarray): Array of shape (N_perms, m) holding all permuted stride vectors.
        N_elems (int): Total number of elements in the tensor (prod(S)).
        flat_domain (np.ndarray): 1D array from 0 to N_elems-1.
        shape_strides (np.ndarray): Row-major (C-order) strides for the shape S
                                    (the last entry is 1).
        flat_to_nd_indices (np.ndarray): Map of shape (1, N_elems, m) to convert flat index to N-D index.
        vector_addresses (np.ndarray): Array of shape (N_perms, N_elems) holding the logical
                                          vector address for each element in each layout.
        smem_address_map (np.ndarray): Array of shape (N_perms, N_elems, max_smem_addrs)
                                          mapping each element to the set of smem words it hits
                                          (padded with -1).
        swizzle_mask (int): The bitmask covering all bits involved in the swizzle.
        swizzled_vector_addresses (np.ndarray): `vector_addresses` after applying the swizzle.
        swizzled_smem_address_map (np.ndarray): `smem_address_map` for swizzled addresses.
        layout_bank_hits (np.ndarray): Array of shape (N_perms, N_warps, n_banks) counting,
                                          for each non-swizzled layout and warp, the
                                          (thread, smem word) accesses that map to each bank.
        swizzled_bank_hits (np.ndarray): Bank hit counts for swizzled layouts.
    """
    def __init__(self, m:int, S:np.ndarray, D:np.ndarray, b_bits:int, m_base:int,
                 s_shift:int, warp_size:int, smem_word_width:int,
                 element_vector_size:int, element_type_width:int, address_atom_width:int, n_banks:int):

        # --- 1. Parameter Validation ---
        assert m <= 4, "Dimensionality m > 4 not supported (too many permutations)"
        assert S.shape == (m,)
        assert D.shape == (m,)
        assert b_bits >= 0
        assert m_base >= 0
        # Keeps the swizzle's source and destination bit fields from overlapping.
        assert abs(s_shift) >= b_bits, "Swizzle shift must be at least b_bits"
        assert np.prod(S).item() <= 1024, "Total elements must be <= 1024"
        assert np.prod(D).item() <= 1024, "Total strides product must be <= 1024"
        # Checked before the modulo below so a zero warp size fails clearly.
        assert warp_size > 0 and n_banks > 0, "warp_size and n_banks must be positive"
        assert (np.prod(S).item() % warp_size == 0), "Total elements must be a multiple of warp_size"
        # Word addresses are computed as (bit // atom) // (atoms per word); that
        # only models whole smem words when the atom width divides the word width.
        assert 0 < address_atom_width <= smem_word_width and smem_word_width % address_atom_width == 0, \
            "smem_word_width must be a positive multiple of address_atom_width"

        self.m = m
        self.S = S
        self.D = D
        self.b_bits = b_bits
        self.m_base = m_base
        self.s_shift = s_shift

        # --- 2. Permutation Generation ---
        self.N_perms = self._factorial(m)
        self.perms = self._generate_perms(m)
        # Apply permutations to the strides D, not shape S
        self.D_permuted = self._apply_perms(self.D, self.perms)

        # --- 3. N-D Index to Linear Address Mapping ---
        self.N_elems = np.prod(S).item()
        self.flat_domain = np.arange(self.N_elems)

        # Row-major (C-order) strides for the shape S: last dimension has stride 1.
        shape_strides = np.ones((m,), dtype=int)
        for i in reversed(range(m - 1)):
            shape_strides[i] = S[i + 1] * shape_strides[i + 1]
        self.shape_strides = shape_strides

        # Use broadcasting to convert flat_domain into N-D indices
        # Shape: (1, N_elems, m)
        self.flat_to_nd_indices = (self.flat_domain.reshape(1, self.N_elems, 1) // self.shape_strides.reshape(1, 1, self.m)) \
                                  % (self.S.reshape(1, 1, self.m))

        # Logical vector address for each element under each stride permutation:
        # the dot product sum(D_permuted * nd_index).
        # Shape: (N_perms, N_elems)
        self.vector_addresses = np.sum(self.D_permuted.reshape(self.N_perms, 1, self.m) * self.flat_to_nd_indices, axis=-1)

        # --- 4. Logical Address to Physical smem Word Mapping ---
        self.smem_address_map = self._layout_to_smem_addr(
            self.vector_addresses,
            element_vector_size, element_type_width,
            smem_word_width, address_atom_width
        )

        # --- 5. Apply Swizzle ---
        self.swizzle_mask, self.swizzled_vector_addresses = self._apply_swizzle(
            self.vector_addresses, b_bits, m_base, s_shift
        )

        self.swizzled_smem_address_map = self._layout_to_smem_addr(
            self.swizzled_vector_addresses,
            element_vector_size, element_type_width,
            smem_word_width, address_atom_width
        )

        # --- 6. Bank Conflict Analysis ---
        self.layout_bank_hits = self._get_bank_hits(
            self.smem_address_map, warp_size, n_banks, self.N_elems, self.N_perms
        )

        self.swizzled_bank_hits = self._get_bank_hits(
            self.swizzled_smem_address_map, warp_size, n_banks, self.N_elems, self.N_perms
        )

    def _factorial(self, n: int) -> int:
        """Calculates factorial n!."""
        return math.factorial(n)

    def _generate_perms(self, n: int) -> np.ndarray:
        """Generates all permutations of range(n), in itertools order, as an (n!, n) int array."""
        return np.array(list(it.permutations(range(n))), dtype=int)

    def _apply_perms(self, X: np.ndarray, perms: np.ndarray) -> np.ndarray:
        """
        Applies all permutations to a vector X.

        Args:
            X (np.ndarray): 1D array of size (m,).
            perms (np.ndarray): 2D array of permutations of shape (N_perms, m).

        Returns:
            np.ndarray: A 2D array of shape (N_perms, m) where row i is X[perms[i]].
        """
        return X[perms]

    def _layout_to_smem_addr(self, layout_addresses: np.ndarray,
                             vector_size: int, element_type_width: int,
                             smem_word_width: int, address_atom_width: int) -> np.ndarray:
        """
        Maps a 2D array of logical vector addresses to their physical smem word sets.

        Args:
            layout_addresses (np.ndarray): 2D array (N_perms, N_elems) of logical vector addresses.
            vector_size (int): Scalar elements per vector.
            element_type_width (int): Bits per scalar element.
            smem_word_width (int): Bits per smem word.
            address_atom_width (int): Bits per addressable atom.

        Returns:
            np.ndarray: 3D array (N_perms, N_elems, max_smem_addrs) mapping
                        each element to its set of smem words. Padded with -1.

        Raises:
            ValueError: If a vector hits more words than the computed bound.

        NOTE(review): -1 is the padding sentinel, so a genuinely negative word
        address (only possible with negative vector addresses, e.g. negative
        strides) would be indistinguishable from padding.
        """
        n_perms, n_elems = layout_addresses.shape
        S, ew, sw, aw = vector_size, element_type_width, smem_word_width, address_atom_width

        # --- Calculate theoretical max smem words per vector (for padding) ---
        N_bits_per_vector = ew * S
        # Worst-case span for N contiguous bits across 'aw'-bit atoms: ceil((N-1)/aw) + 1
        max_n_addr_atoms = ((N_bits_per_vector - 2) // aw) + 2 if N_bits_per_vector > 1 else 1
        # Atoms per smem word
        atoms_per_smem_word = sw // aw
        # Worst-case span for 'max_n_addr_atoms' across 'atoms_per_smem_word' words
        max_n_smem_addrs = ((max_n_addr_atoms - 2) // atoms_per_smem_word) + 2 if max_n_addr_atoms > 1 else 1

        # Pre-fill the output map with -1 (padding)
        layout_to_smem_addr_map = -np.ones((n_perms, n_elems, max_n_smem_addrs), dtype=int)

        # Loop is necessary because np.unique returns ragged (variable-size) arrays
        for i in range(n_perms):
            for j in range(n_elems):
                vec_addr = layout_addresses[i, j]
                smem_addr_set = self._abs_vec_addr_to_smem_addr_set(
                    vec_addr, S, ew, sw, aw
                )

                n_actual = smem_addr_set.size
                if n_actual > max_n_smem_addrs:
                    # Should be unreachable if the bound above is correct.
                    raise ValueError(f"Address {vec_addr} hit {n_actual} words, but bound was {max_n_smem_addrs}")

                layout_to_smem_addr_map[i, j, 0:n_actual] = smem_addr_set

        return layout_to_smem_addr_map

    def _abs_vec_addr_to_smem_addr_set(self, vec_addr: int, vector_size: int,
                                       element_type_width: int, smem_word_width: int,
                                       address_atom_width: int) -> np.ndarray:
        """
        Calculates the unique set of smem words hit by a single vector access.

        The vector occupies scalar elements [vec_addr*vector_size, (vec_addr+1)*vector_size);
        every bit of those elements is mapped to an atom (bit // atom_width) and
        then to a word (atom // atoms_per_word).

        Args:
            vec_addr (int): The logical vector address.
            vector_size (int): Scalar elements per vector.
            element_type_width (int): Bits per scalar element.
            smem_word_width (int): Bits per smem word.
            address_atom_width (int): Bits per addressable atom.

        Returns:
            np.ndarray: A sorted 1D array of the unique smem word addresses hit.
        """
        v, S, ew, sw, aw = vec_addr, vector_size, element_type_width, smem_word_width, address_atom_width

        # 1. Get scalar element indices
        Q = np.arange(S)
        # 2. Get absolute scalar element addresses
        abs_scalar_addr = (v * S) + Q
        # 3. Get bit indices (column vector)
        EW = np.arange(ew).reshape(ew, 1)
        # 4. Broadcast to get all bit addresses: shape (ew, S)
        bit_addrs = EW + (ew * abs_scalar_addr)
        # 5. Convert bit addresses to atom addresses
        atom_addrs = bit_addrs // aw
        # 6. Convert atom addresses to smem word addresses
        smem_addrs = atom_addrs // (sw // aw)

        # 7. Find the unique set of smem words hit
        return np.unique(smem_addrs)

    def _apply_swizzle(self, X: np.ndarray, b_bits: int, m_base: int, s_shift: int) -> tuple[int, np.ndarray]:
        """
        Applies a bitwise XOR swizzle to the layout addresses.

        The b_bits-wide field starting at bit (m_base + max(0, s_shift)) is
        XOR-ed into the field starting at bit (m_base - min(0, s_shift)).

        Args:
            X (np.ndarray): Integer array of layout addresses (N_perms, N_elems).
            b_bits (int): Width of the swizzled bit field.
            m_base (int): Lowest bit position involved in the swizzle.
            s_shift (int): Distance between source and destination fields; a
                positive value moves high bits down, a negative one low bits up.

        Returns:
            tuple: (swizzle_mask, swizzled_X)
                swizzle_mask (int): The mask covering all bits involved.
                swizzled_X (np.ndarray): The swizzled address array.
        """
        # Create masks
        base_mask = (1 << b_bits) - 1

        # Source/Destination shifts based on s_shift direction
        src_shift = m_base + max(0, s_shift)
        dst_shift = m_base - min(0, s_shift)

        src_mask = base_mask << src_shift
        dst_mask = base_mask << dst_shift
        swizzle_mask = src_mask | dst_mask

        # Isolate the bits to move
        bits_to_move = X & src_mask

        # Apply XOR swizzle
        if s_shift > 0:
            # Shift bits from src down to dst
            swizzled_X = X ^ (bits_to_move >> abs(s_shift))
        else:
            # Shift bits from src up to dst
            swizzled_X = X ^ (bits_to_move << abs(s_shift))

        return swizzle_mask, swizzled_X

    def _get_bank_hits(self, layout_to_addr: np.ndarray, warp_size: int,
                       n_banks: int, n_elems: int, n_perms: int) -> np.ndarray:
        """
        Calculates the bank hit statistics for all layouts.

        Args:
            layout_to_addr (np.ndarray): 3D map (perm, elem, smem_addrs) padded with -1.
            warp_size (int): Threads per warp; elements are grouped into warps
                in flat-index order.
            n_banks (int): Number of banks; a word's bank is word % n_banks.
            n_elems (int): Total number of elements (must be a multiple of warp_size).
            n_perms (int): Number of layouts (first axis of layout_to_addr).

        Returns:
            np.ndarray: 3D array (N_perms, N_warps, n_banks) containing the
                        total number of (thread, word) accesses to each bank
                        per warp. Under this model the max is the conflict degree.

        NOTE(review): accesses by different threads to the *same* word are
        counted separately. Hardware that broadcasts same-word reads (e.g.
        NVIDIA GPUs) would see fewer conflicts; confirm the intended model.
        """
        n_warps = n_elems // warp_size
        n_addrs_per_thread = layout_to_addr.shape[-1]

        # Reshape to (perm, warp, thread, addrs)
        addr_map_warped = layout_to_addr.reshape(n_perms, n_warps, warp_size, n_addrs_per_thread)

        bank_hits = np.zeros((n_perms, n_warps, n_banks), dtype=int)

        for p in range(n_perms):
            for w in range(n_warps):
                warp_addrs = addr_map_warped[p, w]
                # Drop the -1 padding before mapping words to banks.
                valid_addrs = warp_addrs[warp_addrs != -1]
                # bincount with minlength yields one count per bank (zeros included).
                bank_hits[p, w] = np.bincount(valid_addrs % n_banks, minlength=n_banks)

        return bank_hits

In [11]:
# Tensor geometry: a 128 x 4 tensor with strides (4, 1).
m = 2
S = np.array([128, 4])
D = np.array([4, 1])

# Swizzle parameters.
b_bits = 3
m_base = 1
s_shift = 4

# Hardware / element geometry.
warp_size = 32
smem_word_width = 32
element_vector_size = 4
element_type_width = 32
address_atom_width = 8
n_banks = 32

L = Layout(
    m=m, S=S, D=D,
    b_bits=b_bits, m_base=m_base, s_shift=s_shift,
    warp_size=warp_size,
    smem_word_width=smem_word_width,
    element_vector_size=element_vector_size,
    element_type_width=element_type_width,
    address_atom_width=address_atom_width,
    n_banks=n_banks,
)


In [12]:
# Bits touched by the swizzle. With b_bits=3, m_base=1, s_shift=4 this is
# 0b11101110 == 238: source bits 5-7 are XOR-ed into destination bits 1-3.
print(L.swizzle_mask)
# Bank access counts for the swizzled layouts, shape (N_perms, N_warps, n_banks).
L.swizzled_bank_hits

238


array([[[4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        ...,
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4]],

       [[4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        [4, 4, 4, ..., 4, 4, 4],
        ...,
        [4, 4, 4, ..., 4, 4, 4],
        [5, 5, 5, ..., 3, 3, 3],
        [5, 5, 5, ..., 3, 3, 3]]], shape=(2, 16, 32))

In [16]:
class LayoutVisualizer(Layout):
    """
    Interactive visualizer for Layout bank conflicts and address maps
    using Bokeh.

    Inherits from Layout to run all calculations, then provides a .display()
    method that builds an interactive dashboard and passes it to Bokeh's
    ``show()``. Where it appears (inline or a new browser tab) depends on the
    caller's Bokeh output configuration (e.g. ``output_notebook()``).

    The constructor takes exactly the same arguments as ``Layout``, positionally
    or by keyword; ``warp_size`` and ``n_banks`` are taken from the arguments
    actually passed, so both call styles produce the same plots.
    """

    def __init__(self, m: int, S: np.ndarray, D: np.ndarray, b_bits: int, m_base: int,
                 s_shift: int, warp_size: int, smem_word_width: int,
                 element_vector_size: int, element_type_width: int,
                 address_atom_width: int, n_banks: int):
        # Run the full Layout calculation
        super().__init__(m, S, D, b_bits, m_base, s_shift, warp_size,
                         smem_word_width, element_vector_size, element_type_width,
                         address_atom_width, n_banks)

        # Store key geometry for plotting (Layout itself does not keep these).
        self.n_banks = n_banks
        self.warp_size = warp_size
        self.n_warps = self.N_elems // self.warp_size

        # --- Visualization Parameters ---
        self.GRID_WIDTH = 64  # How many addresses wide to make the 2D tile
        self.DEFAULT_COLOR = "#f0f0f0"

        # --- Generate Data for Bokeh ---
        print("Generating thread colors...")
        self._generate_thread_colors()

        print("Preparing data sources for all permutations...")
        self._prepare_all_data_sources()
        print("Data preparation complete.")

    def _rgb_to_hex(self, r, g, b):
        """Converts RGB (0-1) to hex string."""
        return '#%02x%02x%02x' % (int(r*255), int(g*255), int(b*255))

    def _generate_thread_colors(self):
        """
        Generates a unique color for each thread (element), in flat-index order.

        Hue = Warp ID; Saturation/Value = Lane ID laid out on a 4-column grid
        with ceil(warp_size / 4) rows (a 4x8 grid for warp_size == 32).
        """
        self.thread_colors_hex = []
        hues = np.linspace(0, 1, self.n_warps, endpoint=False)

        # Lane grid: 4 saturation columns and enough value rows to cover the warp,
        # so warp sizes other than 32 do not index past the end of `vals`.
        n_sat = 4
        n_val = max(1, (self.warp_size + n_sat - 1) // n_sat)
        sats = np.linspace(0.6, 1.0, n_sat)
        vals = np.linspace(0.5, 1.0, n_val)

        for w in range(self.n_warps):
            for l in range(self.warp_size):
                h = hues[w]
                s = sats[l % n_sat]
                v = vals[l // n_sat]

                r, g, b = colorsys.hsv_to_rgb(h, s, v)
                self.thread_colors_hex.append(self._rgb_to_hex(r, g, b))

    def _prepare_all_data_sources(self):
        """Pre-computes and stores all data sources for all permutations in self.all_data."""

        # Global max word address across all permutations sets the grid size.
        self.max_global_addr = int(max(
            np.max(self.smem_address_map),
            np.max(self.swizzled_smem_address_map)
        ))

        self.N_BOXES = self.max_global_addr + 1
        self.N_ROWS = (self.N_BOXES + self.GRID_WIDTH - 1) // self.GRID_WIDTH

        # --- Create the base grid (same for all plots) ---
        addresses = range(self.N_BOXES)
        self.base_grid_data = {
            'x': [i % self.GRID_WIDTH for i in addresses],
            # Y-axis is inverted so 0 is at the top-left
            'y': [self.N_ROWS - 1 - (i // self.GRID_WIDTH) for i in addresses],
            'address': list(addresses),
            'bank': [i % self.n_banks for i in addresses],
            'addr_str': [str(i) for i in addresses],
        }

        # Now, prepare the permutation-specific data
        self.all_data = {p: self._prepare_data_for_perm(p) for p in range(self.N_perms)}

    def _prepare_data_for_perm(self, p: int):
        """
        Processes the raw data for a single permutation into
        Bokeh-ready ColumnDataSources.

        Returns:
            tuple: (map_source, swizzled_map_source, hits_source,
                    swizzled_hits_source, max_hits, swizzled_max_hits), where
                    the two maxima are plain Python ints.

        When several threads hit the same word, the last thread (highest flat
        index) determines that box's color and tooltip.
        """

        # --- 1. Address Map Data ---
        color_map = [self.DEFAULT_COLOR] * self.N_BOXES
        thread_info_map = ["N/A"] * self.N_BOXES

        swizzled_color_map = [self.DEFAULT_COLOR] * self.N_BOXES
        swizzled_thread_info_map = ["N/A"] * self.N_BOXES

        for e in range(self.N_elems): # e = element/thread ID
            warp_id = e // self.warp_size
            lane_id = e % self.warp_size
            thread_color = self.thread_colors_hex[e]
            info = f"Thread {e} (W{warp_id}, L{lane_id})"

            # Process non-swizzled map (-1 entries are padding)
            for addr in self.smem_address_map[p, e]:
                if addr == -1:
                    continue
                color_map[addr] = thread_color
                thread_info_map[addr] = info

            # Process swizzled map
            for addr in self.swizzled_smem_address_map[p, e]:
                if addr == -1:
                    continue
                swizzled_color_map[addr] = thread_color
                swizzled_thread_info_map[addr] = info

        # Fresh list copies per source, so no two ColumnDataSources share a
        # mutable column list with each other or with base_grid_data.
        map_data = {key: list(values) for key, values in self.base_grid_data.items()}
        map_data['color'] = color_map
        map_data['thread_info'] = thread_info_map

        swizzled_map_data = {key: list(values) for key, values in self.base_grid_data.items()}
        swizzled_map_data['color'] = swizzled_color_map
        swizzled_map_data['thread_info'] = swizzled_thread_info_map

        # --- 2. Bank Hit Data (Heatmap) ---
        hits_slice = self.layout_bank_hits[p]
        swizzled_hits_slice = self.swizzled_bank_hits[p]

        # Row-major (warp outer, bank inner) to match hits_slice.reshape(-1).
        # Strings are used because the heatmap axes are categorical.
        warp_col = [str(w) for w in range(self.n_warps) for _ in range(self.n_banks)]
        bank_col = [str(b) for _ in range(self.n_warps) for b in range(self.n_banks)]

        hits_data = {
            'warp': list(warp_col),
            'bank': list(bank_col),
            'hits': hits_slice.reshape(-1).tolist(),
        }
        swizzled_hits_data = {
            'warp': list(warp_col),
            'bank': list(bank_col),
            'hits': swizzled_hits_slice.reshape(-1).tolist(),
        }

        return (
            ColumnDataSource(map_data),
            ColumnDataSource(swizzled_map_data),
            ColumnDataSource(hits_data),
            ColumnDataSource(swizzled_hits_data),
            int(np.max(hits_slice)), # max conflict
            int(np.max(swizzled_hits_slice)) # swizzled max conflict
        )

    def _create_color_legend(self):
        """Creates a Bokeh plot to serve as a 2D color legend for threads."""

        data = {'x': [], 'y': [], 'color': []}
        for w in range(self.n_warps):
            for l in range(self.warp_size):
                data['x'].append(l)
                data['y'].append(w)
                data['color'].append(self.thread_colors_hex[w * self.warp_size + l])

        source = ColumnDataSource(data)

        p = figure(
            width=900,
            height=250,
            title="Thread Color Legend (Y-axis: Warp ID, X-axis: Lane ID)",
            x_range=Range1d(-0.5, self.warp_size - 0.5),
            y_range=Range1d(-0.5, self.n_warps - 0.5),
            tools="hover",
            tooltips=[("Warp", "@y"), ("Lane", "@x")]
        )
        p.rect(x='x', y='y', width=1, height=1, color='color', source=source)
        p.xaxis.axis_label = "Lane ID"
        p.yaxis.axis_label = "Warp ID"
        p.xaxis.ticker = list(range(self.warp_size))
        p.yaxis.ticker = list(range(self.n_warps))
        return p

    def _make_address_map_plot(self, title, source, x_range, y_range, tools):
        """Builds one tiled address-map figure bound to `source`, using the given (shared) ranges."""
        p = figure(
            height=600,
            width=900,
            title=title,
            tools=tools,
            x_range=x_range,
            y_range=y_range,
            # NOTE: per Bokeh docs match_aspect only takes effect with data
            # ranges; with the explicit Range1d used here boxes may not be square.
            match_aspect=True
        )
        boxes = p.rect(x='x', y='y', width=1, height=1, color='color', source=source)
        p.text(x='x', y='y', text='addr_str', source=source,
               text_font_size='7px', text_align='center', text_baseline='middle')
        # A separate HoverTool per figure, attached only to the boxes so the
        # text labels do not produce a duplicate tooltip.
        p.add_tools(HoverTool(
            tooltips=[
                ("Address", "@address"),
                ("Bank", "@bank"),
                ("Access", "@thread_info"),
            ],
            renderers=[boxes]
        ))
        p.xaxis.axis_label = f"Address (Tiled in rows of {self.GRID_WIDTH})"
        p.yaxis.visible = False
        return p

    def _make_bank_hits_plot(self, title, source, color_mapper, x_range, y_range, plot_width):
        """Builds one (bank x warp) heatmap figure colored through `color_mapper`, with a color bar."""
        p = figure(
            height=400,
            width=plot_width,
            x_range=x_range,
            y_range=y_range,
            tools="save",
            x_axis_location="above",
            title=title # Static title
        )
        cells = p.rect(x='bank', y='warp', width=1, height=1, source=source,
                       fill_color={'field': 'hits', 'transform': color_mapper}, line_color=None)
        p.add_tools(HoverTool(
            tooltips=[
                ("Warp", "@warp"),
                ("Bank", "@bank"),
                ("Hits", "@hits"),
            ],
            renderers=[cells]
        ))
        p.xaxis.axis_label = "Bank ID"
        p.yaxis.axis_label = "Warp ID"
        p.add_layout(ColorBar(color_mapper=color_mapper, label_standoff=12, location=(0,0)), 'right')
        return p

    def display(self):
        """
        Renders the interactive Bokeh visualization dashboard.

        Shows a permutation selector, a thread color legend, the non-swizzled
        and swizzled address maps, and the two bank-hit heatmaps. Switching
        permutation is handled client-side by a CustomJS callback.
        """
        # Function-scope import: the palette is only needed here.
        from bokeh.palettes import Reds9

        TOOLS = "pan,wheel_zoom,box_zoom,reset,save"
        plot_width = 900
        # Bokeh's Brewer palettes run dark -> light; reverse them so that
        # zero hits is the lightest cell and the worst conflict the darkest.
        heat_palette = tuple(reversed(Reds9))

        # --- 1. Address maps (shared ranges keep pan/zoom linked) ---
        x_range = Range1d(-0.5, self.GRID_WIDTH - 0.5)
        y_range = Range1d(-0.5, self.N_ROWS - 0.5)
        p_map = self._make_address_map_plot(
            "Non-Swizzled Address Map", self.all_data[0][0], x_range, y_range, TOOLS)
        p_map_swizzled = self._make_address_map_plot(
            "Swizzled Address Map", self.all_data[0][1], x_range, y_range, TOOLS)

        # --- 2. Bank hit heatmaps ---
        max_hits, max_hits_swizzled = self.all_data[0][4], self.all_data[0][5]

        # Use Div for titles (updated by the callback)
        title_hits = Div(
            text=f"<h3>Non-Swizzled Bank Hits (Max Conflict: {max_hits}-way)</h3>",
            width=plot_width
        )
        title_hits_swizzled = Div(
            text=f"<h3>Swizzled Bank Hits (Max Conflict: {max_hits_swizzled}-way)</h3>",
            width=plot_width
        )

        color_mapper = LinearColorMapper(palette=heat_palette, low=0, high=max(1, max_hits))
        color_mapper_swizzled = LinearColorMapper(palette=heat_palette, low=0, high=max(1, max_hits_swizzled))

        bank_labels = [str(i) for i in range(self.n_banks)]
        warp_labels = [str(i) for i in range(self.n_warps)]
        p_hits = self._make_bank_hits_plot(
            "Non-Swizzled Bank Hits", self.all_data[0][2], color_mapper,
            bank_labels, warp_labels, plot_width)
        p_hits_swizzled = self._make_bank_hits_plot(
            "Swizzled Bank Hits", self.all_data[0][3], color_mapper_swizzled,
            p_hits.x_range, p_hits.y_range, plot_width)

        # --- 3. Controls & Callbacks ---
        radio_group = RadioGroup(
            labels=[f"Permutation {i}" for i in range(self.N_perms)],
            active=0
        )

        callback_args = dict(
            # A list (not the int-keyed dict) so JS can index it with the
            # numeric `active` value regardless of how Bokeh serializes dicts.
            all_data=[self.all_data[p] for p in range(self.N_perms)],
            map_source=self.all_data[0][0],
            map_swizzled_source=self.all_data[0][1],
            hits_source=self.all_data[0][2],
            hits_swizzled_source=self.all_data[0][3],
            color_mapper=color_mapper,
            color_mapper_swizzled=color_mapper_swizzled,
            title_hits=title_hits,
            title_hits_swizzled=title_hits_swizzled
        )

        callback_code = """
            const perm_index = cb_obj.active;
            const new_data = all_data[perm_index];
            
            // Update data sources
            map_source.data = new_data[0].data;
            map_swizzled_source.data = new_data[1].data;
            hits_source.data = new_data[2].data;
            hits_swizzled_source.data = new_data[3].data;
            
            // Update color mappers
            const max_hits = new_data[4];
            const max_hits_swizzled = new_data[5];
            
            color_mapper.high = Math.max(1, max_hits);
            color_mapper_swizzled.high = Math.max(1, max_hits_swizzled);
            
            // Update titles
            title_hits.text = `<h3>Non-Swizzled Bank Hits (Max Conflict: ${max_hits}-way)</h3>`;
            title_hits_swizzled.text = `<h3>Swizzled Bank Hits (Max Conflict: ${max_hits_swizzled}-way)</h3>`;
        """

        radio_group.js_on_change('active', CustomJS(args=callback_args, code=callback_code))

        # --- 4. Layout Dashboard ---
        legend = self._create_color_legend()

        dashboard = column(
            radio_group,
            legend,
            row(p_map, p_map_swizzled),
            row(
                column(title_hits, p_hits),
                column(title_hits_swizzled, p_hits_swizzled)
            )
        )

        show(dashboard)



In [32]:
# Tensor geometry: an 8 x 8 tensor with strides (1, 8).
m = 2
S = np.array([8, 8])
D = np.array([1, 8])

# Swizzle parameters.
b_bits = 2
m_base = 1
s_shift = 3

# Hardware / element geometry.
warp_size = 32
smem_word_width = 32
element_vector_size = 2
element_type_width = 32
address_atom_width = 8
n_banks = 32

L = LayoutVisualizer(
    m=m, S=S, D=D,
    b_bits=b_bits, m_base=m_base, s_shift=s_shift,
    warp_size=warp_size,
    smem_word_width=smem_word_width,
    element_vector_size=element_vector_size,
    element_type_width=element_type_width,
    address_atom_width=address_atom_width,
    n_banks=n_banks,
)


Generating thread colors...
Preparing data sources for all permutations...
Data preparation complete.


In [33]:
# Build the Bokeh dashboard for the LayoutVisualizer above and hand it to show().
L.display()