Commit ce39d61

Cohere Command-R Added (#631)
Co-authored-by: Egor Tolmachev <t333ga@gmail.com>
egortolmachev and t3ga committed Apr 12, 2024
1 parent b4b801c commit ce39d61
Showing 4 changed files with 24 additions and 0 deletions.
1 change: 1 addition & 0 deletions auto_gptq/modeling/__init__.py
@@ -3,6 +3,7 @@
 from .baichuan import BaiChuanGPTQForCausalLM
 from .bloom import BloomGPTQForCausalLM
 from .codegen import CodeGenGPTQForCausalLM
+from .cohere import CohereGPTQForCausalLM
 from .decilm import DeciLMGPTQForCausalLM
 from .gemma import GemmaGPTQForCausalLM
 from .gpt2 import GPT2GPTQForCausalLM
1 change: 1 addition & 0 deletions auto_gptq/modeling/_const.py
@@ -24,6 +24,7 @@
     "deci",
     "stablelm_epoch",
     "mpt",
+    "cohere",
 ]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
2 changes: 2 additions & 0 deletions auto_gptq/modeling/auto.py
@@ -6,6 +6,7 @@
 from .baichuan import BaiChuanGPTQForCausalLM
 from .bloom import BloomGPTQForCausalLM
 from .codegen import CodeGenGPTQForCausalLM
+from .cohere import CohereGPTQForCausalLM
 from .decilm import DeciLMGPTQForCausalLM
 from .gemma import GemmaGPTQForCausalLM
 from .gpt2 import GPT2GPTQForCausalLM
@@ -40,6 +41,7 @@
     "moss": MOSSGPTQForCausalLM,
     "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
     "codegen": CodeGenGPTQForCausalLM,
+    "cohere": CohereGPTQForCausalLM,
     "RefinedWebModel": RWGPTQForCausalLM,
     "RefinedWeb": RWGPTQForCausalLM,
     "falcon": RWGPTQForCausalLM,
20 changes: 20 additions & 0 deletions auto_gptq/modeling/cohere.py
@@ -0,0 +1,20 @@
+from logging import getLogger
+
+from ._base import BaseGPTQForCausalLM
+
+
+logger = getLogger(__name__)
+
+class CohereGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "CohereDecoderLayer"
+    layers_block_name = "model.layers"
+    outside_layer_modules = ["model.embed_tokens", "model.norm"]
+    inside_layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
+        ["self_attn.o_proj"],
+        ["mlp.up_proj", "mlp.gate_proj"],
+        ["mlp.down_proj"],
+    ]
+
+__all__ = ["CohereGPTQForCausalLM"]
+
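With the model definition in place, a Command-R checkpoint can be quantized through the usual AutoGPTQ flow. The sketch below follows the standard quantization API; the checkpoint id and the single calibration sentence are placeholders, and it assumes a transformers release that ships the Cohere architecture:

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_id = "CohereForAI/c4ai-command-r-v01"  # placeholder checkpoint id
quantized_model_dir = "command-r-gptq-4bit"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id)
# A real run needs a proper calibration dataset; one sentence is only a demo.
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)

# The "cohere" entries added in this commit make this resolve to CohereGPTQForCausalLM.
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_id, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)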