Skip to content

Failing to generate MLIR for Llama3 using TorchMLIR #4242

Open
@HemKava

Description

@HemKava

I downloaded Llama3 model to hf-files directory and then trying to use AutoModelForCausalLM to load the model, and then convert the transformer portion to MLIR.

huggingface-cli download meta-llama/Meta-Llama-3-8B --local-dir ./hf-files

import torch
from transformers import AutoModelForCausalLM
import torch_mlir.fx as fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._mode_utils import no_dispatch

print("Loading model...")
full = AutoModelForCausalLM.from_pretrained(
    "./hf-files",
    local_files_only=True,
    torch_dtype="auto"
).eval()

core = full.model
print("Model loaded")

# Setup FakeTensorMode
fake_mode = FakeTensorMode()

with fake_mode, no_dispatch():
    dummy = torch.randint(0, core.config.vocab_size, (1, 16), dtype=torch.long)
    print("Fake dummy input created")

    # fx.export_and_import will internally call torch.export.export
    mlir_mod = fx.export_and_import(core, dummy)
    print("MLIR module created")

with open("llama3_transformer.mlir", "w") as f:
    f.write(str(mlir_mod))
print("Saved to llama3_transformer.mlir")

I am seeing following AssertionError error in generating the MLIR. Any pointers will be helpful:

Traceback (most recent call last):
File "hf-to-mlir3.py", line 25, in
mlir_mod = fx.export_and_import(core, dummy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/torch-mlir/build/python_packages/torch_mlir/torch_mlir/fx.py", line 98, in export_and_import
prog = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 319, in export
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 286, in export
return _export(
^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2172, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export
) = make_fake_inputs(
^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 347, in make_fake_inputs
fake_args, fake_kwargs = tree_map_with_path(
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in tree_map_with_path
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 1197, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 348, in
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 162, in fakify
fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2943, in from_tensor
return self.fake_tensor_converter.from_real_tensor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 399, in from_real_tensor
out = self.meta_converter(
^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1913, in call
r = self.meta_tensor(
^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 894, in meta_tensor
assert not torch._C._dispatch_tls_local_exclude_set().has(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions