In [10]:

import torchhd
import torch

def cartesian_bind_tensor(tensors: list[torch.Tensor]) -> torch.Tensor:
    """
    Bind every combination of one hypervector taken from each input set.

    An index grid over all combinations is built with torch.cartesian_prod,
    the selected rows of every set are gathered, stacked along a new dim=1,
    and reduced with torchhd.multibind.

    Args:
        tensors: non-empty list of K hypervector sets, each indexed along
            dim 0 (expected shape [N_k, D]).

    Returns:
        torchhd.multibind applied to the stacked [N_prod, K, D] tensor,
        i.e. [N_prod, D] with N_prod = prod(N_k). Rows follow the
        torch.cartesian_prod ordering of the per-set indices.

    Raises:
        ValueError: if `tensors` is empty.
    """
    if not tensors:
        raise ValueError("Need at least one set")

    # 1) build the cartesian product of row indices (all on the first set's device)
    index_device = tensors[0].device
    index_ranges = [torch.arange(t.shape[0], device=index_device) for t in tensors]
    grids = torch.cartesian_prod(*index_ranges)  # → [N_prod, K] when K >= 2

    # torch.cartesian_prod hands back its lone argument unchanged (1-D) when
    # K == 1; add the missing column axis so the per-set indexing below works.
    # unsqueeze (rather than reshape(-1, K)) also stays valid for empty sets.
    if grids.dim() == 1:
        grids = grids.unsqueeze(1)

    # 2) for each position k, gather the selected hypervectors → [N_prod, D]
    hv_list = [hv_set[grids[:, k]] for k, hv_set in enumerate(tensors)]

    # 3) stack them into [N_prod, K, D]
    stacked = torch.stack(hv_list, dim=1)

    # 4) multibind along dim=1 → [N_prod, D]
    return torchhd.multibind(stacked)

def cartesian_bind_tensor_2(list_tensors):
    """
    Bind every combination of one hypervector taken from each input set,
    using a torch.meshgrid index grid.

    Args:
        list_tensors: non-empty sequence of P hypervector sets, each indexed
            along dim 0 (expected shape [Nₚ, D]).

    Returns:
        Tensor produced by torchhd.multibind on the stacked [B, P, D] tensor,
        i.e. [B, D] with B = ∏Nₚ. Row b is the bind of
        list_tensors[p][iₚ] for p in 0..P-1, where (i₁,…,iₚ) is the b-th
        entry of the flattened 'ij' meshgrid. Only this tensor is returned;
        no index keys are produced.

    Raises:
        ValueError: if `list_tensors` is empty (same guard as the sibling
            cartesian-bind helpers in this notebook).
    """
    # len() rather than truthiness so a stacked tensor input is still accepted
    if len(list_tensors) == 0:
        raise ValueError("Need at least one set")

    # 1) one index range per set, each on that set's own device
    ranges = [torch.arange(t.shape[0], device=t.device) for t in list_tensors]
    # P grids, each shaped [N₁,…,Nₚ]
    grids = torch.meshgrid(*ranges, indexing='ij')

    # 2) flatten every grid to [B] and gather the matching rows → list of [B, D]
    gathered = [
        hv_set[grid.reshape(-1)]
        for hv_set, grid in zip(list_tensors, grids)
    ]

    # 3) stack into [B, P, D] and bind across P → [B, D]
    stacked = torch.stack(gathered, dim=1)
    return torchhd.multibind(stacked)

import torch
import torchhd
from typing import List

def cartesian_bind_tensor_3(tensors: List[torch.Tensor]) -> torch.Tensor:
    """
    Given a list of P hypervector sets [N_p, D], return the bind of every
    combination of one row from each set.

    Args:
        tensors: non-empty list of P hypervector sets, each indexed along
            dim 0 (expected shape [N_p, D]).

    Returns:
        torchhd.multibind applied to the stacked [B, P, D] tensor, i.e.
        [∏N_p, D]; rows follow the torch.cartesian_prod ordering of the
        per-set indices.

    Raises:
        ValueError: if `tensors` is empty.
        IndexError: if the first set has no second (feature) dimension.
    """
    # Number of domains (P)
    P = len(tensors)
    if P == 0:
        raise ValueError("Need at least one tensor")
    # Kept from the original: reading shape[1] fails fast when the first set
    # is not at least 2-D, instead of silently binding along the wrong axis.
    feature_dim = tensors[0].shape[1]  # noqa: F841
    device = tensors[0].device

    # 1) Build all combinations of indices: shape [B, P]
    ranges = [torch.arange(t.shape[0], device=device) for t in tensors]
    idx_grid = torch.cartesian_prod(*ranges)
    # torch.cartesian_prod returns its single argument unchanged (1-D) when
    # P == 1; restore the column axis so idx_grid[:, p] is valid. unsqueeze is
    # used instead of reshape(-1, P) so empty sets keep working.
    if idx_grid.dim() == 1:
        idx_grid = idx_grid.unsqueeze(1)

    # 2) Gather the corresponding hypervectors: list of [B, D]
    gathered = [hv_set[idx_grid[:, p]] for p, hv_set in enumerate(tensors)]

    # 3) Stack into [B, P, D] and multibind → [B, D]
    stacked = torch.stack(gathered, dim=1)
    return torchhd.multibind(stacked)

## Prepare Data

We fix `D=3000` and benchmark 2-way Cartesian binds for N = 100, 500, 1000. The result of a 2-way bind has N² rows, so memory use grows quadratically with N.

In [15]:
# Cell 2: generate datasets
#
# For every benchmark size N, draw two independent random MAP hypervector
# sets of N rows and D columns on the CPU; the seed makes the draws repeatable.
# NOTE(review): the bind helpers stack a [N*N, 2, D] intermediate for a pair of
# sets; at N=1000, D=3000 that is 6e9 elements (~24 GB if elements are 4-byte
# floats — TODO confirm dtype), so the largest size may not fit in memory.

torch.manual_seed(0)
D = 3000
sizes = [100, 500, 1000]

data_sets = {
    N: [torchhd.random(N, D, vsa="MAP", device="cpu") for _ in range(2)]
    for N in sizes
}

In [None]:
# Cell 3: benchmarking loop
#
# Times two of the Cartesian-bind implementations on each dataset with
# IPython's %timeit (-o returns the result object) and records the mean
# per-loop time in `results[N]`.
# Note on labels: 'mesh' holds the timing of cartesian_bind_tensor (index grid
# built with torch.cartesian_prod) and 'bcast' holds the timing of
# cartesian_bind_tensor_2 (index grid built with torch.meshgrid).

results = {}

for N, sets in data_sets.items():
    print(f"\n--- N = {N} ---")

    # cartesian_bind_tensor → stored under 'mesh'
    time_mesh = get_ipython().run_line_magic('timeit', '-o cartesian_bind_tensor(sets)')

    # cartesian_bind_tensor_2 → stored under 'bcast'
    time_bcast = get_ipython().run_line_magic('timeit', '-o cartesian_bind_tensor_2(sets)')

    # only recorded once both timings for this N have completed
    results[N] = dict(mesh=time_mesh.average, bcast=time_bcast.average)


--- N = 100 ---
26.4 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.8 ms ± 932 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

--- N = 500 ---
1.93 s ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.2 s ± 265 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

--- N = 1000 ---
