Add support for MPT (#73)
* Add initial support for MPT

* Ruff and stuff
LaaZa committed Mar 28, 2024
1 parent ca187fb commit a9f3d46
Showing 4 changed files with 23 additions and 0 deletions.
1 change: 1 addition & 0 deletions auto_gptq/modeling/__init__.py
@@ -15,6 +15,7 @@
 from .mistral import MistralGPTQForCausalLM
 from .mixtral import MixtralGPTQForCausalLM
 from .moss import MOSSGPTQForCausalLM
+from .mpt import MPTGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
 from .phi import PhiGPTQForCausalLM
 from .qwen import QwenGPTQForCausalLM
1 change: 1 addition & 0 deletions auto_gptq/modeling/_const.py
@@ -23,6 +23,7 @@
     "xverse",
     "deci",
     "stablelm_epoch",
+    "mpt",
 ]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
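For orientation (commentary, not part of the commit): SUPPORTED_MODELS is the allow-list AutoGPTQ consults before it will load a checkpoint, so this one-line addition is what lets model_type == "mpt" through. A rough sketch of that check, assuming the usual Hugging Face config lookup (the call site itself is not shown in this diff):

    from transformers import AutoConfig

    from auto_gptq.modeling._const import SUPPORTED_MODELS

    # Sketch: reject unsupported architectures up front, as AutoGPTQ does.
    # MPT checkpoints historically need trust_remote_code=True for their config.
    config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    if config.model_type not in SUPPORTED_MODELS:  # "mpt" passes after this commit
        raise TypeError(f"{config.model_type} isn't supported yet.")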
2 changes: 2 additions & 0 deletions auto_gptq/modeling/auto.py
@@ -18,6 +18,7 @@
 from .mistral import MistralGPTQForCausalLM
 from .mixtral import MixtralGPTQForCausalLM
 from .moss import MOSSGPTQForCausalLM
+from .mpt import MPTGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
 from .phi import PhiGPTQForCausalLM
 from .qwen import QwenGPTQForCausalLM
@@ -56,6 +57,7 @@
     "longllama": LongLlamaGPTQForCausalLM,
     "gemma": GemmaGPTQForCausalLM,
     "phi": PhiGPTQForCausalLM,
+    "mpt": MPTGPTQForCausalLM,
 }


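GPTQ_CAUSAL_LM_MODEL_MAP is the dispatch table the AutoGPTQForCausalLM front end indexes after reading model_type from the model's config, so registering the "mpt" key here is what routes MPT checkpoints to the new class. A minimal sketch of that lookup (the import path and call shape are assumptions, not shown in this diff):

    from transformers import AutoConfig

    from auto_gptq.modeling.auto import GPTQ_CAUSAL_LM_MODEL_MAP

    # Sketch of the auto-class dispatch: config.model_type selects the wrapper.
    config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    quant_cls = GPTQ_CAUSAL_LM_MODEL_MAP[config.model_type]
    print(quant_cls.__name__)  # -> "MPTGPTQForCausalLM"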
19 changes: 19 additions & 0 deletions auto_gptq/modeling/mpt.py
@@ -0,0 +1,19 @@
+from auto_gptq.modeling import BaseGPTQForCausalLM
+
+
+class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MPTBlock"
+    layers_block_name = "transformer.blocks"
+    outside_layer_modules = [
+        "transformer.wte", "transformer.norm_f"
+    ]
+
+    inside_layer_modules = [
+        ["attn.Wqkv"],
+        ["attn.out_proj"],
+        ["ffn.up_proj"],
+        ["ffn.down_proj"]
+    ]
+
+
+__all__ = ["MPTGPTQForCausalLM"]
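A note on the new class (commentary, not part of the diff): layers_block_name points at MPT's stack of MPTBlock modules, outside_layer_modules names the embedding and final norm that sit outside that stack, and inside_layer_modules lists the Linear submodules GPTQ quantizes in order inside each block: the fused QKV projection, the attention output projection, then the two feed-forward projections. With the registrations above, an MPT checkpoint goes through the standard AutoGPTQ flow. A minimal end-to-end sketch, assuming the public AutoGPTQ API and mosaicml/mpt-7b as an example checkpoint (both assumptions; use a real calibration set in practice):

    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    from transformers import AutoTokenizer

    model_id = "mosaicml/mpt-7b"  # assumed example; any MPT checkpoint should work
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # 4-bit GPTQ with group size 128; common defaults, not mandated by this commit.
    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

    # from_pretrained reads config.model_type ("mpt") and, via the map added in
    # auto.py above, instantiates MPTGPTQForCausalLM.
    model = AutoGPTQForCausalLM.from_pretrained(
        model_id, quantize_config, trust_remote_code=True
    )

    # Tiny calibration set purely for illustration; real runs want more examples.
    examples = [
        tokenizer(
            "auto_gptq is an easy-to-use model quantization library.",
            return_tensors="pt",
        )
    ]
    model.quantize(examples)
    model.save_quantized("mpt-7b-4bit-gptq")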
