[BUG/DEPRECATION] Remove fused attention/mlp #659

Closed
wants to merge 91 commits
Changes from all commits (91 commits)
b33e1d0
remove (zeors -= 1)
qwopqwop200 Feb 21, 2024
6419a2d
add warning
qwopqwop200 Feb 21, 2024
5ea98bb
support backwards compatibility
qwopqwop200 Feb 22, 2024
0b07292
support and fix bug
qwopqwop200 Feb 25, 2024
b015ae9
remove not necessary parm
qwopqwop200 Feb 26, 2024
b7d9ade
fix test_q4 bug
qwopqwop200 Mar 2, 2024
83e510e
fix test_q4 bug
qwopqwop200 Mar 2, 2024
a89cb77
fix bug double converting
qwopqwop200 Mar 2, 2024
639e66a
Update _utils.py
qwopqwop200 Mar 2, 2024
15ecb0f
Merge branch 'main' into main
qwopqwop200 Mar 2, 2024
0759aea
merge main #1
Qubitium Apr 12, 2024
101647d
FIX type error
LRL-ModelCloud Apr 12, 2024
28a2541
module is nn.Module
LRL-ModelCloud Apr 12, 2024
3899d5b
sync name
Qubitium Apr 12, 2024
9f25c21
need return module
LRL-ModelCloud Apr 12, 2024
4159474
modify default format to gptq_v2
LRL-ModelCloud Apr 12, 2024
c853143
fix need return model
LRL-ModelCloud Apr 12, 2024
047dd97
remove fixme and default to gptq_v2 for quantize_config
Qubitium Apr 12, 2024
869a162
save _qlinear_kernel and allow save to older format
Qubitium Apr 12, 2024
365d961
fix name
Qubitium Apr 12, 2024
e363d85
pass quantize
Qubitium Apr 12, 2024
6af41db
Merge remote-tracking branch 'origin/main' into sym-false
Qubitium Apr 12, 2024
bd12bde
update
Qubitium Apr 12, 2024
cd916fe
store quant log/stats in dict slice and return to user in quantize()
Qubitium Apr 13, 2024
d35e57d
accept saved quant_log in quantize() and calculate diff
Qubitium Apr 13, 2024
b0d2ad5
tqdm the layer loop
Qubitium Apr 14, 2024
45cbc6b
log awq vs autogptq outputs in awq compat test
Qubitium Apr 14, 2024
23b4b65
fix cached models is not compatible with new pr. add v2 to cache file…
Qubitium Apr 14, 2024
dc5d3dd
add deprecation warning for loading .bin/.pt weights
Qubitium Apr 14, 2024
fec3f1c
add missing termcolor req
Qubitium Apr 14, 2024
2921b27
spell
Qubitium Apr 14, 2024
7b97f73
fix triton v2
Qubitium Apr 14, 2024
c5c98da
rename quant log column 'name' to 'module'
Qubitium Apr 14, 2024
5b95194
ruff
Qubitium Apr 14, 2024
19189aa
add quantization tests for sym=False
Qubitium Apr 14, 2024
070e6c8
spell
Qubitium Apr 14, 2024
0103785
fix type hint
Qubitium Apr 14, 2024
61a8713
more testing, fix serialization bug, no additional dependency
fxmarty Apr 15, 2024
c6b3632
fix version
Qubitium Apr 16, 2024
83e932d
no need for ... in tqdm
Qubitium Apr 16, 2024
a3863a4
Use threadpoolctl to limit packing threads
Qubitium Apr 16, 2024
7b89b89
layer # sync with visual tqdm
Qubitium Apr 16, 2024
ed0fe60
use thread limit 1: as good as 4 and 1 beats 16 threads in testing
Qubitium Apr 16, 2024
37ed02e
fix saving of gptq (v1)
Qubitium Apr 16, 2024
5d66558
deep copy
Qubitium Apr 16, 2024
9d88639
remove todo: verified
Qubitium Apr 16, 2024
0b84da7
TEST/DEBUG underflow protection and output underflow stats
Qubitium Apr 18, 2024
9e3fea0
force underflow math (testing shows this is better than skipping math)
Qubitium Apr 18, 2024
005a942
1) disable serialization of sym=False to v1 by default. 2) disable lo…
Qubitium Apr 22, 2024
365618d
revert adding underflow check/stats
Qubitium Apr 23, 2024
31d4027
pass test_quant and test both v1 and v2 save/load
Qubitium Apr 23, 2024
ac63033
performance fix for convert_v1/v2().
Qubitium Apr 24, 2024
e8eabe2
need to ++ version so can delimit models make pre/post pr
Qubitium Apr 24, 2024
e4936f5
add meta and meta.quantizer to quantized_config.json
Qubitium Apr 24, 2024
c822d63
fix json save and add meta check to test_quantization. distutils is d…
Qubitium Apr 24, 2024
35cc701
fix failed test
Qubitium Apr 24, 2024
cddbe23
fix awq unpack/repacking thread regression
Qubitium Apr 24, 2024
811ca39
remove highly flaky mistral tiny test with input/output that is nonse…
Qubitium Apr 24, 2024
2104cae
now we can detect quant producer, we don't need use_unsafe_math for l…
Qubitium Apr 24, 2024
bc2bf5b
updat tests
Qubitium Apr 24, 2024
528a8fc
default to gptq v1 for max compat and remove use_unsafe_math check in…
Qubitium Apr 24, 2024
464fc7e
misc
Qubitium Apr 25, 2024
59be4b3
separate the concept of meta.quantizer and meta.packer (intel/auto-ro…
Qubitium Apr 25, 2024
79050ea
clean
Qubitium Apr 25, 2024
4450702
test allow loading quantized lm_head
Qubitium Apr 25, 2024
79a2f76
test loading of quantized lm_head
Qubitium Apr 25, 2024
d85f3f2
rename
Qubitium Apr 25, 2024
55336b5
fix quantized lm_head loading
Qubitium Apr 26, 2024
64c604c
fix quantized lm_head loading
Qubitium Apr 26, 2024
842fd3c
update
Qubitium Apr 26, 2024
d086437
sync with main
Qubitium Apr 26, 2024
7b1d115
add unittest test_lm_head
Qubitium Apr 27, 2024
8ac1b87
backport h100 fixed marlin kernel from vllm
Qubitium Apr 27, 2024
25decb0
Revert "backport h100 fixed marlin kernel from vllm"
Qubitium Apr 27, 2024
aa2e385
revert
Qubitium Apr 27, 2024
e474efd
fix h100
Qubitium Apr 27, 2024
5b307da
revert debug code
Qubitium Apr 27, 2024
5b456a6
now that h100 is validated, remove hopper check
Qubitium Apr 27, 2024
7f031b3
warn users if quantization using insufficient nsamples
Qubitium Apr 25, 2024
96e0b05
remove fused attention/mlp
Qubitium Apr 28, 2024
7ca7260
continue
Qubitium Apr 28, 2024
fe191f5
ADD GLM model support
LRL-ModelCloud Jun 6, 2024
806d8d0
ADD GLM model support
LRL-ModelCloud Jun 6, 2024
847b27e
Merge pull request #1 from Qubitium/sym-false-lm-head
Qubitium Jun 15, 2024
fc26717
Merge pull request #2 from Qubitium/add-glm-v2
Qubitium Jun 15, 2024
d27f26f
Merge pull request #4 from Qubitium/fix-h100
Qubitium Jun 15, 2024
8dad030
Merge branch 'main' into quantize-lm-head
Qubitium Jun 15, 2024
ea58a81
Merge pull request #3 from Qubitium/quantize-lm-head
Qubitium Jun 15, 2024
5e9ca27
Merge branch 'main' into nsamples-sanity-check
Qubitium Jun 15, 2024
647d228
Merge pull request #5 from Qubitium/nsamples-sanity-check
Qubitium Jun 15, 2024
9b1e155
Merge branch 'main' into remove-fused-attention
Qubitium Jun 15, 2024
4 changes: 1 addition & 3 deletions auto_gptq/__init__.py
@@ -1,6 +1,4 @@
from .modeling import AutoGPTQForCausalLM, BaseQuantizeConfig
from .utils.exllama_utils import exllama_set_max_input_length
from .utils.peft_utils import get_gptq_peft_model


__version__ = "0.8.0.dev0"
from .version import __version__
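The package version now lives in a dedicated auto_gptq/version.py module instead of being hard-coded in __init__.py. A minimal sketch of what that module presumably contains; the exact version string is an assumption, not taken from this diff:

# auto_gptq/version.py -- assumed contents: the single source of truth for the package version
__version__ = "0.8.0.dev0"  # placeholder; the PR bumps the version to delimit pre/post-PR models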
1 change: 1 addition & 0 deletions auto_gptq/modeling/__init__.py
@@ -6,6 +6,7 @@
from .cohere import CohereGPTQForCausalLM
from .decilm import DeciLMGPTQForCausalLM
from .gemma import GemmaGPTQForCausalLM
from .chatglm import ChatGLMForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
406 changes: 208 additions & 198 deletions auto_gptq/modeling/_base.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions auto_gptq/modeling/_const.py
@@ -15,6 +15,7 @@
"moss",
"gpt_bigcode",
"codegen",
"chatglm",
"RefinedWebModel",
"RefinedWeb",
"baichuan",
@@ -24,7 +25,6 @@
"deci",
"stablelm_epoch",
"mpt",
"cohere",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
@@ -42,8 +42,9 @@
SUPPORTED_MODELS.append("phi")
if compare_transformers_version("v4.38.0", op="ge"):
SUPPORTED_MODELS.append("gemma")
if compare_transformers_version("v4.39.0.dev0", op="ge"):
if compare_transformers_version("v4.39.0", op="ge"):
SUPPORTED_MODELS.append("starcoder2")
SUPPORTED_MODELS.append("cohere")

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

116 changes: 93 additions & 23 deletions auto_gptq/modeling/_utils.py
@@ -5,6 +5,7 @@
from typing import List, Optional, Union

import accelerate
import threadpoolctl as tctl
import numpy as np
import torch
import torch.nn as nn
@@ -13,6 +14,7 @@
from transformers import AutoConfig
from transformers.utils.hub import cached_file

from ..quantization import BaseQuantizeConfig
from ..utils.import_utils import dynamically_import_QuantLinear
from ..utils.modeling_utils import recurse_setattr
from ._const import CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, SUPPORTED_MODELS
@@ -147,6 +149,71 @@ def make_quant(
recurse_setattr(module, name, new_layer.to(ori_layer_device))


def convert_gptq_v1_to_v2_format(
model,
quantize_config: BaseQuantizeConfig,
qlinear_kernel: nn.Module,
):
use_qigen = qlinear_kernel.QUANT_TYPE == "qigen"

# Limit thread usage to avoid auto-parallelization regression
with tctl.threadpool_limits(limits=1):
for _, submodule in model.named_modules():
# The v1 checkpoint format applied `qzeros -= 1` before serialization, so the
# additions here do not overflow.
# A v1 checkpoint with sym=False saved via convert_gptq_v2_to_v1_format() will
# overflow in roughly <=13% of cases based on testing.
if isinstance(submodule, qlinear_kernel):
if use_qigen:
submodule.zeros.data += 1
else:
if quantize_config.bits == 2:
submodule.qzeros.data += 0b01010101010101010101010101010101
elif quantize_config.bits == 3:
submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] += 0b00100100100100100100100100100100
submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] += 0b10010010010010010010010010010010
submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] += 0b01001001001001001001001001001001
elif quantize_config.bits == 4:
submodule.qzeros.data += 0b00010001000100010001000100010001
elif quantize_config.bits == 8:
submodule.qzeros.data += 0b00000001000000010000000100000001
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")

return model


def convert_gptq_v2_to_v1_format(
model,
quantize_config: BaseQuantizeConfig,
qlinear_kernel: nn.Module,
):
use_qigen = qlinear_kernel.QUANT_TYPE == "qigen"

# Limit thread usage to avoid auto-parallelization regression
with tctl.threadpool_limits(limits=1):
for _, submodule in model.named_modules():
# sym=False underflowed in roughly <=13% of cases during testing. No underflow is possible for sym=True.
if isinstance(submodule, qlinear_kernel):
if use_qigen:
submodule.zeros.data -= 1
else:
if quantize_config.bits == 2:
submodule.qzeros.data -= 0b01010101010101010101010101010101
elif quantize_config.bits == 3:
submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] -= 0b00100100100100100100100100100100
submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] -= 0b10010010010010010010010010010010
submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] -= 0b01001001001001001001001001001001
elif quantize_config.bits == 4:
submodule.qzeros.data -= 0b00010001000100010001000100010001
elif quantize_config.bits == 8:
submodule.qzeros.data -= 0b00000001000000010000000100000001
else:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")

return model
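To see why the packed constants above bump every zero-point by exactly one: each constant places a 1 in every packed field (every 2-bit pair for 2-bit, every nibble for 4-bit, every byte for 8-bit, and three staggered patterns across consecutive int32 columns for 3-bit, whose fields straddle word boundaries), so a single integer addition adjusts all zero-points at once. A standalone sketch for the 4-bit case, using plain Python integers rather than PR code:

# Sketch: v1 stores `zeros - 1`; v2 restores the real zero-points by adding 1 to
# every 4-bit field in one shot. Because no stored nibble is 15, no carry can
# cross a nibble boundary, which is what makes the single packed addition safe.
ZERO_MINUS_ONE = 7  # a sym=True 4-bit zero-point of 8, stored as 7 in v1

# pack eight 4-bit values into one 32-bit word, low nibble first
packed_v1 = sum(ZERO_MINUS_ONE << (4 * i) for i in range(8))

# the same constant convert_gptq_v1_to_v2_format() adds for bits == 4
packed_v2 = (packed_v1 + 0b00010001000100010001000100010001) & 0xFFFFFFFF

print([(packed_v2 >> (4 * i)) & 0xF for i in range(8)])  # [8, 8, 8, 8, 8, 8, 8, 8]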


def preprocess_checkpoint_qigen(
module,
names,
@@ -297,32 +364,36 @@ def pack_model(
)
qlayers = find_layers(model, [QuantLinear])

pbar = tqdm(qlayers.keys(), leave=True)
for name in pbar:
pbar.set_description(f"Packing {name}...", refresh=True)

quantizers[name], scale, zero, g_idx = quantizers[name]
# so far can only pack layer on CPU
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = (
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear.QUANT_TYPE == "marlin":
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)
# Limit pack() thread usage to avoid auto-parallelization regression
with tctl.threadpool_limits(limits=1):
pbar = tqdm(qlayers.keys(), leave=True)
for name in pbar:
pbar.set_description(f"Packing {name}")

quantizers[name], scale, zero, g_idx = quantizers[name]
# so far can only pack layer on CPU
layer_device = qlayers[name].device
qlayers[name].to(CPU)
layers[name], scale, zero, g_idx = (
layers[name].to(CPU),
scale.to(CPU),
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear.QUANT_TYPE == "marlin":
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)
qlayers[name].to(layer_device)

logger.info("Model packed.")

if use_triton and warmup_triton:
logger.warning(
"using autotune_warmup will move model to GPU, make sure you have enough VRAM to load the whole model."
)
QuantLinear.warmup(model.to(CUDA_0), seqlen=model.seqlen)
return QuantLinear


def check_and_get_model_type(model_dir, trust_remote_code=False):
@@ -475,7 +546,7 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
submodule.post_init()

## exllamav2
# exllamav2
fixed_bytes = {}
model_uses_exllamav2 = False

@@ -574,8 +645,6 @@ def unpack_awq(
torch.int16 if bits == 8 else torch.int8
)

# zeros = zeros + 1

torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)

zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
@@ -671,7 +740,6 @@ def pack_from_tensors(
qweight = qweight.astype(np.int32)
qweight = torch.from_numpy(qweight)

unpacked_qzeros = unpacked_qzeros - 1
torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)

unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
@@ -775,4 +843,6 @@ def get_checkpoints(model_name_or_path: str, extensions: List[str], possible_mod
"check_and_get_model_type",
"simple_dispatch_model",
"make_sure_no_tensor_in_meta_device",
"convert_gptq_v1_to_v2_format",
"convert_gptq_v2_to_v1_format",
]
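The pack_model() change above wraps the packing loop in threadpoolctl's threadpool_limits(limits=1); per the commit log, a limit of 1 was as fast as 4 threads and beat 16 threads in testing. A minimal sketch of the pattern, where do_pack and the matrix sizes are placeholders rather than code from this PR:

# Sketch of the threadpoolctl pattern used in pack_model(): clamp BLAS/OpenMP
# thread pools to a single thread for the duration of a loop, then restore the
# previous limits automatically when the context manager exits.
import numpy as np
import threadpoolctl as tctl

def do_pack(a, b):
    # stand-in for the per-layer qlayers[name].pack(...) work
    return a @ b

with tctl.threadpool_limits(limits=1):
    for _ in range(4):
        do_pack(np.ones((256, 256)), np.ones((256, 256)))  # runs single-threaded
# outside the with-block, library thread pools return to their previous limits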
6 changes: 2 additions & 4 deletions auto_gptq/modeling/auto.py
@@ -9,6 +9,7 @@
from .cohere import CohereGPTQForCausalLM
from .decilm import DeciLMGPTQForCausalLM
from .gemma import GemmaGPTQForCausalLM
from .chatglm import ChatGLMForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
@@ -39,6 +40,7 @@
"llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM,
"moss": MOSSGPTQForCausalLM,
"chatglm": ChatGLMForCausalLM,
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
"codegen": CodeGenGPTQForCausalLM,
"cohere": CohereGPTQForCausalLM,
@@ -98,8 +100,6 @@ def from_quantized(
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
inject_fused_attention: bool = False,
inject_fused_mlp: bool = False,
use_cuda_fp16: bool = True,
quantize_config: Optional[BaseQuantizeConfig] = None,
model_basename: Optional[str] = None,
@@ -148,8 +148,6 @@
device=device,
low_cpu_mem_usage=low_cpu_mem_usage,
use_triton=use_triton,
inject_fused_attention=inject_fused_attention,
inject_fused_mlp=inject_fused_mlp,
use_cuda_fp16=use_cuda_fp16,
quantize_config=quantize_config,
model_basename=model_basename,
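With inject_fused_attention and inject_fused_mlp removed from from_quantized(), existing call sites simply drop those keyword arguments; the remaining options shown in the signature above are unchanged. A hedged migration sketch, with a hypothetical repository id:

from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "some-org/some-model-GPTQ",  # hypothetical model id
    device="cuda:0",
    use_triton=False,
    # inject_fused_attention=True,  # no longer accepted after this PR
    # inject_fused_mlp=True,        # no longer accepted after this PR
)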
15 changes: 15 additions & 0 deletions auto_gptq/modeling/chatglm.py
@@ -0,0 +1,15 @@
from ._base import BaseGPTQForCausalLM


class ChatGLMForCausalLM(BaseGPTQForCausalLM):
layer_type = "GLMBlock"
layers_block_name = "transformer.encoder.layers"
outside_layer_modules = ["transformer.embedding.word_embeddings", "transformer.output_layer"]
inside_layer_modules = [
["self_attention.query_key_value"],
["self_attention.dense"],
["mlp.dense_h_to_4h"],
["mlp.dense_4h_to_h"],
]

__all__ = ["ChatGLMForCausalLM"]
1 change: 0 additions & 1 deletion auto_gptq/modeling/cohere.py
@@ -17,4 +17,3 @@ class CohereGPTQForCausalLM(BaseGPTQForCausalLM):
]

__all__ = ["CohereGPTQForCausalLM"]

10 changes: 0 additions & 10 deletions auto_gptq/modeling/decilm.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class DeciLMGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["DeciLMGPTQForCausalLM"]
3 changes: 0 additions & 3 deletions auto_gptq/modeling/gptj.py
@@ -1,4 +1,3 @@
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel
from ._base import BaseGPTQForCausalLM


@@ -13,7 +12,5 @@ class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.fc_out"],
]

fused_attn_module_type = FusedGPTJAttentionForQuantizedModel


__all__ = ["GPTJGPTQForCausalLM"]
11 changes: 0 additions & 11 deletions auto_gptq/modeling/llama.py
@@ -1,16 +1,8 @@
from logging import getLogger

from ..utils.import_utils import compare_transformers_version
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +17,5 @@ class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["LlamaGPTQForCausalLM"]
10 changes: 0 additions & 10 deletions auto_gptq/modeling/longllama.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class LongLlamaGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["LongLlamaGPTQForCausalLM"]
2 changes: 1 addition & 1 deletion auto_gptq/modeling/mpt.py
@@ -1,4 +1,4 @@
from auto_gptq.modeling import BaseGPTQForCausalLM
from ._base import BaseGPTQForCausalLM


class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
10 changes: 0 additions & 10 deletions auto_gptq/modeling/stablelmepoch.py
@@ -4,13 +4,6 @@
from ._base import BaseGPTQForCausalLM


if compare_transformers_version("v4.28.0", op="ge"):
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel
else:
FusedLlamaAttentionForQuantizedModel = None
FusedLlamaMLPForQuantizedModel = None

logger = getLogger(__name__)


@@ -25,8 +18,5 @@ class StableLMEpochGPTQForCausalLM(BaseGPTQForCausalLM):
["mlp.down_proj"],
]

fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
fused_mlp_module_type = FusedLlamaMLPForQuantizedModel


__all__ = ["StableLMEpochGPTQForCausalLM"]