
Conversation

@erichuang-cienet (Contributor)

I encountered a ConfigAttributeError while running the test_spmd.py test. After I fixed the ConfigAttributeError, a new error appeared: cannot import name '_histogram' from 'torch_xla.experimental.custom_kernel'.

I found that the _histogram function had been removed from the torch_xla repository, so I copied the deleted function into torchprime to pass the test for now.
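
For reference, here is a minimal sketch of what such a histogram helper looks like (the function actually copied in this PR may differ in names and details). It counts how many tokens were routed to each expert without torch.bincount, whose data-dependent output shape does not play well with XLA:

import torch


def _histogram(indices: torch.Tensor, min_val: int, max_val: int) -> torch.Tensor:
  """Count occurrences of each integer in [min_val, max_val] in `indices`.

  The Gmm path uses a helper like this to turn the flattened top-k expert ids
  into per-expert token counts (group sizes) with a static output shape.
  """
  assert indices.dtype == torch.int32, "indices must be int32"
  assert min_val <= max_val
  bins = torch.arange(min_val, max_val + 1, dtype=indices.dtype, device=indices.device)
  # Compare every index against every bin value and sum the matches per bin.
  return (bins.unsqueeze(1) == indices.flatten()).sum(dim=1)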

Environment:

  • TPU VM: v6e-8
  • Python 3.11
  • torch 2.9.0.dev20250825+cpu
  • torch-xla 2.9.0+git8243a25

Test command:

pytest -v torchprime/torch_xla_models/tests/test_spmd.py

Original error:

torchprime/torch_xla_models/tests/test_spmd.py:413: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
torchprime/torch_xla_models/tests/test_spmd.py:426: in assert_same_output_weights_grad
    assert_output(model_config_sharded, model_fsdp_v2_sharded, input, labels)
torchprime/torch_xla_models/tests/test_spmd.py:454: in assert_output
    config_logits, config_loss = model_config_sharded(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:985: in forward
    outputs = self.model(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:951: in forward
    hidden_states, load_balance_loss = self.layers(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
torchprime/layers/sequential.py:34: in forward
    input = module(*splat(input), **broadcasted_inputs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
torchprime/sharding/shard_model.py:317: in forward
    return self.mark_sharding(self._orig_mod(*args, **kwargs), self.spec)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:869: in forward
    hidden_states = self.self_attn(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:235: in forward
    attn_output = self.attention_block(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/attention.py:46: in forward
    if self.config.attention_kernel != "splash_attention":
venv/lib/python3.11/site-packages/omegaconf/dictconfig.py:355: in __getattr__
    self._format_and_raise(
venv/lib/python3.11/site-packages/omegaconf/base.py:231: in _format_and_raise
    format_and_raise(
venv/lib/python3.11/site-packages/omegaconf/_utils.py:899: in format_and_raise
    _raise(ex, cause)
venv/lib/python3.11/site-packages/omegaconf/_utils.py:797: in _raise
    raise ex.with_traceback(sys.exc_info()[2])  # set env var OC_CAUSE=1 for full trace
venv/lib/python3.11/site-packages/omegaconf/dictconfig.py:351: in __getattr__
    return self._get_impl(
venv/lib/python3.11/site-packages/omegaconf/dictconfig.py:442: in _get_impl
    node = self._get_child(
venv/lib/python3.11/site-packages/omegaconf/basecontainer.py:73: in _get_child
    child = self._get_node(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = {'vocab_size': 32000, 'hidden_size': 4096, 'initializer_range': 0.02, 'intermediate_size': 14336, 'num_hidden_layers':...s_coef': 0.02, 'attention_bias': False, 'attention_dropout': 0.0, 'flash_attention': True, 'moe_implementation': 'gmm'}
key = 'attention_kernel', validate_access = True, validate_key = False, throw_on_missing_value = False, throw_on_missing_key = True

    def _get_node(
        self,
        key: DictKeyType,
        validate_access: bool = True,
        validate_key: bool = True,
        throw_on_missing_value: bool = False,
        throw_on_missing_key: bool = False,
    ) -> Optional[Node]:
        try:
            key = self._validate_and_normalize_key(key)
        except KeyValidationError:
            if validate_access and validate_key:
                raise
            else:
                if throw_on_missing_key:
                    raise ConfigAttributeError
                else:
                    return None
    
        if validate_access:
            self._validate_get(key)
    
        value: Optional[Node] = self.__dict__["_content"].get(key)
        if value is None:
            if throw_on_missing_key:
>               raise ConfigKeyError(f"Missing key {key!s}")
E               omegaconf.errors.ConfigAttributeError: Missing key attention_kernel
E                   full_key: attention_kernel
E                   object_type=dict

venv/lib/python3.11/site-packages/omegaconf/dictconfig.py:480: ConfigAttributeError
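
This failure comes from attribute-style access on an OmegaConf DictConfig that has no attention_kernel entry, so cfg.attention_kernel raises. I fixed it by adding the attention_kernel key to the config; purely for illustration, a defaulted lookup would also avoid the exception (a sketch only, not the change in this PR; the "flash_attention" default is made up):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"flash_attention": True})  # no attention_kernel key

# cfg.attention_kernel would raise ConfigAttributeError here;
# a defaulted lookup returns a fallback value instead.
kernel = cfg.get("attention_kernel", "flash_attention")
# or: kernel = OmegaConf.select(cfg, "attention_kernel", default="flash_attention")
print(kernel)  # -> flash_attention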

Error after adding the attention_kernel config key:

torchprime/torch_xla_models/tests/test_spmd.py:414: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
torchprime/torch_xla_models/tests/test_spmd.py:427: in assert_same_output_weights_grad
    assert_output(model_config_sharded, model_fsdp_v2_sharded, input, labels)
torchprime/torch_xla_models/tests/test_spmd.py:455: in assert_output
    config_logits, config_loss = model_config_sharded(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:985: in forward
    outputs = self.model(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:951: in forward
    hidden_states, load_balance_loss = self.layers(
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
torchprime/layers/sequential.py:34: in forward
    input = module(*splat(input), **broadcasted_inputs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
torchprime/sharding/shard_model.py:317: in forward
    return self.mark_sharding(self._orig_mod(*args, **kwargs), self.spec)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:879: in forward
    hidden_states, router_logits, loss = self.block_sparse_moe(hidden_states)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:829: in forward
    final_hidden_states = self.experts(hidden_states, selected_experts)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
torchprime/torch_xla_models/model/mixtral/model.py:635: in forward
    return Gmm.apply(hidden_states, top_ks, self.w1, self.w2, self.w3)
venv/lib/python3.11/site-packages/torch/autograd/function.py:581: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
venv/lib/python3.11/site-packages/torch_xla/debug/profiler.py:190: in wrapper_trace_me
    return func(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

ctx = <torch.autograd.function.GmmBackward object at 0x76bd92d9aad0>
hidden_states = tensor([[ 0.0376, -1.4229, -1.1879,  ...,  0.9058, -0.7395, -2.2306],
        [ 0.2465,  0.2539, -0.5946,  ...,  1.363...
        [ 1.4304,  1.0818,  0.3875,  ...,  0.2306, -0.1513,  1.0394]],
       device='xla:0', grad_fn=<ViewBackward0>)
top_ks = tensor([[2, 0],
        [7, 3],
        [0, 5],
        ...,
        [4, 5],
        [4, 7],
        [6, 4]], device='xla:0')
w1 = Parameter containing:
tensor([[[ 8.5701e-05, -3.3926e-05, -2.0257e-05,  ..., -4.2514e-05,
          -6.8237e-05, -1.64... 1.0675e-04, -6.6270e-05,  ..., -3.2797e-05,
           1.0043e-04, -1.2965e-04]]], device='xla:0', requires_grad=True)
w2 = Parameter containing:
tensor([[[-1.8818e-05, -1.0192e-04,  7.8132e-05,  ...,  1.1593e-04,
          -2.3545e-05,  1.28... 4.9709e-05,  5.0439e-05,  ..., -2.1213e-05,
          -9.3021e-05,  1.0746e-04]]], device='xla:0', requires_grad=True)
w3 = Parameter containing:
tensor([[[ 3.1825e-06,  8.2731e-05,  9.9913e-06,  ..., -1.1615e-04,
           3.0409e-05,  1.04... 1.0521e-04, -8.7700e-05,  ..., -1.1096e-04,
           8.7636e-05,  1.2744e-04]]], device='xla:0', requires_grad=True)

    @staticmethod
    @xp.trace_me("gmm_forward")
    def forward(
      ctx,
      hidden_states: torch.Tensor,
      top_ks: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
      w3: torch.Tensor,
    ) -> torch.Tensor:
      """
      Integrated with PyTorch/XLA Pallas gmm:
    
      lhs: [m, hidden_size]
      top_ks: [m, k]
      w1: [num_experts, hidden_size, ffn_dim]
      w2: [num_experts, ffn_dim, hidden_size]
      w3: [num_experts, hidden_size, ffn_dim]
      """
>     from torch_xla.experimental.custom_kernel import _histogram, gmm
E     ImportError: cannot import name '_histogram' from 'torch_xla.experimental.custom_kernel' (/home/sa_106246016191652103077/torchprime/venv/lib/python3.11/site-packages/torch_xla/experimental/custom_kernel.py)

torchprime/torch_xla_models/model/mixtral/model.py:355: ImportError
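
Since newer torch_xla builds no longer export _histogram, the import in model.py could also be written with a fallback to the vendored copy. A sketch of that pattern (the local module path torchprime.torch_xla_models.model.mixtral.histogram is hypothetical, not the actual location in this PR):

# Prefer the upstream helper when it still exists; otherwise use the local copy.
try:
  from torch_xla.experimental.custom_kernel import _histogram
except ImportError:
  # Hypothetical location of the vendored helper.
  from torchprime.torch_xla_models.model.mixtral.histogram import _histogram

from torch_xla.experimental.custom_kernel import gmm  # the gmm kernel import is unchanged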

@erichuang-cienet force-pushed the erichuang/fix-test_spmd branch 2 times, most recently from 11ab27f to e86579a on September 1, 2025 at 01:36
@jack8558 (Collaborator) commented on Sep 3, 2025:

Could you pull the latest main to see if the CI passes?

@jack8558 jack8558 enabled auto-merge (squash) September 12, 2025 16:03
@jack8558 jack8558 disabled auto-merge September 15, 2025 16:20
@jack8558 jack8558 merged commit f5c6d08 into AI-Hypercomputer:main Sep 15, 2025
9 of 11 checks passed