
[BUG] @jit_fuser fails with Unknown type constructor Sequence #880


Description

@Edenzzzz

Describe the bug
Using torch 2.1.1, running bash examples/bert/train_bert_340m_distributed.sh fails with a JIT error: TorchScript cannot handle the Sequence[int] return annotation on vocab_range_from_per_partition_vocab_size, which gets pulled in when the fused cross-entropy helpers (e.g. calculate_logits_max / calculate_predicted_logits) are scripted under @jit_fuser:

  File "/root/sharedDisk/home/tanwenxuan/miniconda3/lib/python3.8/site-packages/torch/jit/_script.py", line 1381, in script
    return torch.jit.script(fn, _rcb=rcb)
  File "/root/sharedDisk/home/tanwenxuan/miniconda3/lib/python3.8/site-packages/torch/jit/_script.py", line 1381, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
Unknown type constructor Sequence:
  File "/root/sharedDisk/home/tanwenxuan/Megatron-LM/megatron/core/tensor_parallel/utils.py", line 106
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
         ~~~~~~~~~~~~~ <--- HERE
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
'vocab_range_from_per_partition_vocab_size' is being compiled since it was called from 'calculate_predicted_logits'
  File "/root/sharedDisk/home/tanwenxuan/Megatron-LM/megatron/core/tensor_parallel/cross_entropy.py", line 41
    
        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
'calculate_predicted_logits' is being compiled since it was called from 'calculate_predicted_logits'
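
For reference, the failure reproduces outside Megatron-LM as well: TorchScript does not understand typing.Sequence, so scripting any function that carries that annotation fails with the same message. A minimal sketch (the function below is a simplified, illustrative stand-in for VocabUtility.vocab_range_from_per_partition_vocab_size, not the actual Megatron code):

```python
# Minimal sketch: typing.Sequence is not a TorchScript-supported type, so
# scripting a function annotated with it raises the error from the trace above.
from typing import Sequence

import torch


def vocab_range(per_partition_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
    # Simplified stand-in for VocabUtility.vocab_range_from_per_partition_vocab_size
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


torch.jit.script(vocab_range)  # RuntimeError: Unknown type constructor Sequence
```

Changing the annotation to a TorchScript-supported type such as Tuple[int, int] should let the same snippet compile.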

To Reproduce
bash examples/bert/train_bert_340m_distributed.sh

Expected behavior
The training script runs without JIT compilation errors.

Stack trace/logs
See the traceback in the description above.

Environment (please complete the following information):

  • Megatron-LM e33c8f7
  • PyTorch 2.1.1
  • CUDA 12.1

Proposed fix
Disable JIT scripting for these helpers when this error occurs, i.e. fall back to the plain Python functions instead of failing the run.
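
One possible shape for that workaround (a sketch only, not the existing @jit_fuser implementation): attempt scripting and fall back to the undecorated function if TorchScript compilation raises.

```python
# Sketch of the proposed workaround, not the actual @jit_fuser decorator:
# keep the plain Python function when TorchScript compilation fails instead
# of aborting training.
import torch


def jit_fuser_with_fallback(fn):
    try:
        return torch.jit.script(fn)
    except RuntimeError:
        # e.g. "Unknown type constructor Sequence" on torch 2.1.1
        return fn
```

Alternatively, dropping or changing the Sequence[int] annotation on vocab_range_from_per_partition_vocab_size (e.g. to Tuple[int, int]) would likely let the existing scripting path succeed.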
Additional context
