Description
I downloaded Llama3 model to hf-files directory and then trying to use AutoModelForCausalLM to load the model, and then convert the transformer portion to MLIR.
huggingface-cli download meta-llama/Meta-Llama-3-8B --local-dir ./hf-files
import torch
from transformers import AutoModelForCausalLM
import torch_mlir.fx as fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._mode_utils import no_dispatch
print("Loading model...")
full = AutoModelForCausalLM.from_pretrained(
"./hf-files",
local_files_only=True,
torch_dtype="auto"
).eval()
core = full.model
print("Model loaded")
# Setup FakeTensorMode
fake_mode = FakeTensorMode()
with fake_mode, no_dispatch():
dummy = torch.randint(0, core.config.vocab_size, (1, 16), dtype=torch.long)
print("Fake dummy input created")
# fx.export_and_import will internally call torch.export.export
mlir_mod = fx.export_and_import(core, dummy)
print("MLIR module created")
with open("llama3_transformer.mlir", "w") as f:
f.write(str(mlir_mod))
print("Saved to llama3_transformer.mlir")
I am seeing following AssertionError error in generating the MLIR. Any pointers will be helpful:
Traceback (most recent call last):
File "hf-to-mlir3.py", line 25, in
mlir_mod = fx.export_and_import(core, dummy)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/torch-mlir/build/python_packages/torch_mlir/torch_mlir/fx.py", line 98, in export_and_import
prog = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 319, in export
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/init.py", line 286, in export
return _export(
^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2172, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1159, in wrapper
raise e
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1125, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/exported_program.py", line 123, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export
) = make_fake_inputs(
^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 347, in make_fake_inputs
fake_args, fake_kwargs = tree_map_with_path(
^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in tree_map_with_path
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 1197, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/utils/_pytree.py", line 2077, in
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 348, in
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_export/non_strict_utils.py", line 162, in fakify
fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2943, in from_tensor
return self.fake_tensor_converter.from_real_tensor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 399, in from_real_tensor
out = self.meta_converter(
^^^^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 1913, in call
r = self.meta_tensor(
^^^^^^^^^^^^^^^^^
File "/venv/lib64/python3.11/site-packages/torch/_subclasses/meta_utils.py", line 894, in meta_tensor
assert not torch._C._dispatch_tls_local_exclude_set().has(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError