In [13]:
from torchhd import VTBTensor

from src.encoding.the_types import VSAModel

"""
autograd_check.py

A small script that verifies whether torchhd.bind, torchhd.bundle, torchhd.multibind,
torchhd.multibundle, and scatter_hd all support backpropagation (i.e., produce valid
gradients under PyTorch’s autograd). We test this for each TorchHD VSA model (MAP, HRR,
BSC/Ternary, etc.) and for each of the four reduction ops.

Usage:
    python autograd_check.py
"""

import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.utils import scatter
import torch.optim as optim
import torchhd

# Import all available hypervector types from torchhd.tensors.*
from torchhd.tensors.map import MAPTensor       #  [oai_citation:0‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com)
from torchhd.tensors.hrr import HRRTensor         #  [oai_citation:1‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com)
from torchhd.tensors.bsbc import BSBCTensor       #  [oai_citation:2‡aidoczh.com](https://www.aidoczh.com/torchhd/_modules/torchhd/tensors/bsbc.html?utm_source=chatgpt.com)
from torchhd.tensors.fhrr import FHRRTensor       # (if installed)  [oai_citation:3‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)

import random

# 1) gather all torchhd tensor‐classes to test
VSA_TYPES = {
    "MAP": MAPTensor,
    "HRR": HRRTensor,
    "FHRR": FHRRTensor,
    "VTB": VTBTensor,
}


# 2) helper: wrap torchhd.bind / bundle etc. in a single place
OP_FUNCS = {
    "bind": torchhd.bind,
    "bundle": torchhd.bundle,
}

MULTI_FUNCS = {
    "multibind": torchhd.multibind,
    "multibundle": torchhd.multibundle,
}

# 3) Scatter‐HD helper (flattened) as given in the prompt
from typing import Literal
ReductionOP = Literal["bind", "bundle"]

def scatter_hd(
        src: Tensor,
        index: Tensor,
        *,
        op: ReductionOP,
        dim_size: int | None = None,
) -> Tensor:
    """
    Scatter-reduce a batch of hypervectors along dim=0 using
    either torchhd.bind or torchhd.bundle, with minimal overhead
    for MAP, BSC and HRR models.

    Args:
        src (Tensor): hypervector batch of shape [N, D, ...] where
                      N is the “items” dimension to scatter over.
        index (LongTensor): shape [N], bucket indices in [0..dim_size).
        op (Callable): either torchhd.bind or torchhd.bundle.
        dim_size (int, optional): number of output buckets.
                                   If None, uses index.max()+1.

    Returns:
        Tensor: scattered & reduced hypervectors of shape
                [dim_size, D, ...], same dtype/device as src.
    """
    # infer output size
    if dim_size is None:
        dim_size = int(index.max().item()) + 1

    # dispatch on type and op
    reduce = ""
    if isinstance(src, MAPTensor):
        # MAP bind == elementwise multiply → scatter-mul
        # MAP bundle == elementwise sum → use pyg scatter-sum
        reduce = "sum" if op == "bundle" else "mul"
    elif isinstance(src, HRRTensor) and op == "bundle":
        # HRR bundle == elementwise sum
        reduce = "sum"
        # HRR bind (circular conv) not supported by pyg
    if reduce:
        # When the dim_size is bigger than the addressed indexes, the scatter stacks zero vectors to reach the desired
        # dimensions, this is not correct in the hyperdimensional algebra. There we need identity vectors, such vectors
        # that when bound with a random-hypervector X, the result is X. Therefore we need to add them manually
        idx_dim = int(index.max().item()) + 1
        result = scatter(src, index, dim=0, dim_size=idx_dim, reduce=reduce)

        if (num_identity_vectors := dim_size - idx_dim) == 0:
            return result

        # TODO: Improve this
        vsa = VSAModel.HRR
        if VSAModel.MAP.value in repr(type(src)):
            vsa = VSAModel.MAP
        # elif VSAModel.BSC.value in repr(type(src)):
        #     vsa = VSAModel.BSC
        identities = torchhd.identity(
            num_vectors=num_identity_vectors, dimensions=src.shape[-1], vsa=vsa.value, device=src.device
        )
        return torch.cat([result, identities])

    # Generic fallback: group rows manually in Python (will be slower)
    # Currently no support for dim other than0
    buckets = [[] for _ in range(dim_size)]
    for i, b in enumerate(index.tolist()):
        buckets[b].append(src[i])

    # initialize output slots
    op_hd = torchhd.multibind if op == "bind" else torchhd.multibundle
    out = []
    for bucket in buckets:
        if not bucket:
            # empty bucket → identity for bind, zero for bundle
            identity = type(src).identity(1, src.shape[-1], device=src.device).squeeze(0)
            out.append(identity)
        else:
            # reduce the list by repeatedly applying op
            reduced = op_hd(torch.stack(bucket, dim=0))
            out.append(reduced)
    return torch.stack(out, dim=0)


def autograd_test_for_vsa(vsa_name: str, vsa_cls: type):
    """
    For a single VSA type (e.g. MAPTensor or HRRTensor), check that:
     - `bind(a, b)` is differentiable.
     - `bundle(a, b)` is differentiable.
     - `multibind(torch.stack([a,b,c]))` is differentiable.
     - `multibundle(torch.stack([a,b,c]))` is differentiable.
     - `scatter_hd(src, idx, op=…)` is differentiable.

    Returns True if all four ops support autograd without error, False otherwise.
    """
    torch.manual_seed(0)
    device = torch.device("cpu")

    success = True
    dim = 8*8  # very small dimensionality, just for testing

    # Create two random float‐valued hypervectors of that VSA class, requiring grad:
    # We create them by sampling from .random() or .random_like(...) if available.
    try:
        # Many torchhd classes have a .random(...) or .random_() factory method.
        a = vsa_cls.random(1, dim, requires_grad=True, device=device).squeeze(0)
        b = vsa_cls.random(1, dim, requires_grad=True, device=device).squeeze(0)
        c = vsa_cls.random(1, dim, requires_grad=True, device=device).squeeze(0)
    except AttributeError:
        # Fallback: sample floats in {-1, +1} for MAP/HRR, or {0,1} for BSC, then wrap
        vals = torch.randint(0, 2, (1, dim), device=device).float() * 2 - 1  # ±1 for MAP/HRR
        a = vsa_cls(vals.clone().requires_grad_(True))
        b = vsa_cls(vals.clone().requires_grad_(True))
        c = vsa_cls(vals.clone().requires_grad_(True))

    # Create a dummy “edge_weight” parameter we can optimize
    # shape [num_edges, 1]: we pick num_edges=3 so that scatter_hd has 2 buckets
    edge_weight = torch.randn(3, 1, requires_grad=True, device=device)

    # Prepare dummy “messages” to scatter: shape [3, dim]
    # We’ll just stack a, b, c and pretend they are messages for three edges
    # (with edge_weight gating).
    messages = torch.stack([a, b, c], dim=0)  # but a,b,c are VSA‐specific objects
    if isinstance(a, torch.Tensor):
        # Already a raw tensor
        pass
    else:
        # If a is an HV wrapper (e.g. MAPTensor), extract its data‐tensor:
        messages = torch.stack([a.data, b.data, c.data], dim=0)
    # Create an index tensor that lumps indices 0,1→bucket0 and index2→bucket1
    idx = torch.tensor([0, 0, 1], dtype=torch.long, device=device)

    # 1) Test bind(a,b)
    try:
        bound = torchhd.bind(a, b)                   #  [oai_citation:4‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:5‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        if not bound.requires_grad:
            bound = bound.requires_grad_()
        opt = optim.Adam([bound], lr=1e-2)
        out = bound.sum()
        out.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] bind(...) failed autograd: {e}")
        success = False

    # 2) Test bundle(a,b)
    try:
        bundled = torchhd.bundle(a, b)               #  [oai_citation:6‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:7‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        if not bundled.requires_grad:
            bundled = bundled.requires_grad_()
        opt = optim.Adam([bundled], lr=1e-2)
        out = bundled.sum()
        out.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] bundle(...) failed autograd: {e}")
        print(e)
        success = False

    # 3) Test multibind(torch.stack([a,b,c]))
    try:
        multi_bound = torchhd.multibind(torch.stack([a, b, c], dim=0))   #  [oai_citation:8‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:9‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        if not multi_bound.requires_grad:
            multi_bound = multi_bound.requires_grad_()
        opt = optim.Adam([multi_bound], lr=1e-2)
        out = multi_bound.sum()
        out.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] multibind(...) failed autograd: {e}")
        print(e)
        success = False

    # 4) Test multibundle(torch.stack([...]))
    try:
        multi_bundled = torchhd.multibundle(torch.stack([a, b, c], dim=0)) #  [oai_citation:10‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:11‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        if not multi_bundled.requires_grad:
            multi_bundled = multi_bundled.requires_grad_()
        opt = optim.Adam([multi_bundled], lr=1e-2)
        out = multi_bundled.sum()
        out.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] multibundle(...) failed autograd: {e}")
        print(e)
        success = False

    # 5) Test scatter_hd(messages, idx, op="bind")
    try:
        # Note: scatter_hd expects HV‐typed inputs, so wrap messages back to vsa
        hv_messages = vsa_cls(messages.clone())
        bucketed = scatter_hd(hv_messages, idx, op="bind", dim_size=2)  #  [oai_citation:12‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:13‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        # bucketed has shape [2, dim]
        # To test gradient, we artificially add and sum:
        logits = bucketed.sum(dim=1, keepdim=True).sum()
        # Create a dummy optimizer on the logits (it will backprop through bucketed and thus through edge_weight)
        opt = optim.Adam([logits.requires_grad_(True)], lr=1e-2)
        out = logits * edge_weight.sum()  # force linkage to edge_weight
        out.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] scatter_hd(..., op='bind') failed autograd: {e}")
        print(e)
        success = False

    # 6) Test scatter_hd(messages, idx, op="bundle")
    try:
        hv_messages = vsa_cls(messages.clone())
        bucketed2 = scatter_hd(hv_messages, idx, op="bundle", dim_size=2)  #  [oai_citation:14‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/torchhd.html?utm_source=chatgpt.com) [oai_citation:15‡torchhd.readthedocs.io](https://torchhd.readthedocs.io/en/stable/getting_started.html?utm_source=chatgpt.com)
        logits2 = bucketed2.sum(dim=1, keepdim=True).sum()
        opt = optim.Adam([logits2.requires_grad_(True)], lr=1e-2)
        out2 = logits2 * edge_weight.sum()
        out2.backward(retain_graph=True)
    except Exception as e:
        print(f"[{vsa_name}] scatter_hd(..., op='bundle') failed autograd: {e}")
        print(e)
        success = False

    return success


def main():
    print("=== TorchHD Autograd Compatibility Check ===\n")
    all_passed = True
    for name, cls in VSA_TYPES.items():
        ok = autograd_test_for_vsa(name, cls)
        print(f"{name} autograd pass: {ok}")
        if not ok:
            all_passed = False
    if all_passed:
        print("\nALL TorchHD ops support autograd for all tested VSA types.")
    else:
        print("\nSOME TorchHD ops failed autograd. See the messages above.")

main()

=== TorchHD Autograd Compatibility Check ===

[MAP] bind(...) failed autograd: can't optimize a non-leaf Tensor
[MAP] bundle(...) failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[MAP] multibind(...) failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[MAP] multibundle(...) failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[MAP] scatter_hd(..., op='bind') failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[MAP] scatter_hd(..., op='bundle') failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
MAP autograd pass: False
[HRR] bind(...) failed autograd: can't optimize a non-leaf Tensor
[HRR] bundle(...) failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[HRR] multibind(...) failed autograd: can't optimize a non-leaf Tensor
can't optimize a non-leaf Tensor
[HRR] multibundle(...) failed autograd: can't o

In [14]:
import torch
import torchhd
from torchhd.tensors.map import MAPTensor

# 1) Create two _leaf_ MAP hypervectors:
leaf_a = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)  # leaf
leaf_b = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)  # leaf

# 2) Bind them → result is non-leaf
bound = torchhd.bind(leaf_a, leaf_b)
print(bound.is_leaf)       # False
print(bound.grad_fn)       # e.g. <AsStridedBackward0 object at 0x...>

# 3) Bundle them → result is non-leaf
bundled = torchhd.bundle(leaf_a, leaf_b)
print(bundled.is_leaf)     # False
print(bundled.grad_fn)     # e.g. <AliasBackward0 object at 0x...>

# 4) Now, propose a simple “loss” and call backward:
loss = (bound + bundled).sum()
loss.backward()            # No error! a.grad and b.grad populated

False
<AliasBackward0 object at 0x1498f90f0>
False
<AliasBackward0 object at 0x1498fa260>


In [15]:
# Continuing from above: leaf_a, leaf_b, but add leaf_c
leaf_c = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)  # leaf

# 1) multibind: bind multiple leaves at once → non-leaf
multi_bound = torchhd.multibind(torch.stack([leaf_a, leaf_b, leaf_c], dim=0))
print(multi_bound.is_leaf)      # False
print(multi_bound.grad_fn)      # e.g. <AddmmBackward0 object at 0x...> (depends on implementation)

# 2) multibundle: bundle multiple leaves at once → non-leaf
multi_bundled = torchhd.multibundle(torch.stack([leaf_a, leaf_b, leaf_c], dim=0))
print(multi_bundled.is_leaf)    # False
print(multi_bundled.grad_fn)    # e.g. <SumBackward0 object at 0x...>

# 3) Backprop through them
loss2 = (multi_bound + multi_bundled).sum()
loss2.backward()  # Works fine; leaf_a.grad, leaf_b.grad, leaf_c.grad all non-null

False
<AliasBackward0 object at 0x1498f9a80>
False
<AliasBackward0 object at 0x1498f9c30>


In [17]:
# Suppose we have three leaf hypervectors (for edges) and a leaf logit parameter:
leaf_a = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)
leaf_b = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)
leaf_c = MAPTensor.random(1, 64, requires_grad=True).squeeze(0)
edge_logits = torch.randn(3, 1, requires_grad=True)   # leaf parameter

# Stack them into a “messages” tensor
msgs = torch.stack([leaf_a.data, leaf_b.data, leaf_c.data], dim=0)  # raw Tensor [3,64]
hv_msgs = MAPTensor(msgs)       # wrapping back into a MAPTensor (non-leaf)

# Indices to scatter into 2 buckets: first two into bucket0, last into bucket1
idx = torch.tensor([0, 0, 1], dtype=torch.long)

# 1) Scatter‐HD with bind
bucketed_bind = scatter_hd(hv_msgs, idx, op="bind", dim_size=2)
# bucketed_bind is a MAPTensor of shape [2, 64], non‐leaf

# 2) Form a scalar “loss” that depends on both bucketed_bind and edge_logits:
out = (bucketed_bind.sum(dim=-1).sum() * edge_logits.sum())
out.backward()  # No error → edge_logits.grad is non-null

# 3) Similarly, for bundle
bucketed_bundle = scatter_hd(hv_msgs, idx, op="bundle", dim_size=2)
out2 = (bucketed_bundle.sum(dim=-1).sum() * edge_logits.sum())
out2.backward()  # Works, edge_logits.grad accumulates