Support for Marlin kernel
Andrei-Aksionov committed Feb 18, 2024
1 parent bf9601c commit 23cd55f
Showing 3 changed files with 276 additions and 96 deletions.
102 changes: 69 additions & 33 deletions generate/base.py
@@ -1,5 +1,8 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
+# TODO: remove it
+# flake8: noqa
+import importlib
 import json
 import sys
 import time
@@ -104,8 +107,20 @@ def main(
     top_k: Optional[int] = 200,
     temperature: float = 0.8,
     checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
-    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq"]] = None,
-    kernel: Optional[Literal["cuda", "exllama", "exllamav2", "triton"]] = None,
+    quantize: Optional[
+        Literal[
+            "bnb.nf4",
+            "bnb.nf4-dq",
+            "bnb.fp4",
+            "bnb.fp4-dq",
+            "bnb.int8",
+            "gptq.int2",
+            "gptq.int3",
+            "gptq.int4",
+            "gptq.int8",
+        ]
+    ] = None,
+    kernel: Optional[Literal["cuda_old", "cuda", "exllama", "exllamav2", "triton", "marlin"]] = None,
     precision: Optional[str] = None,
     compile: bool = False,
 ) -> None:
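
Note (not part of the diff): a minimal sketch of how the widened options compose, assuming `main()` keeps lit-gpt's usual leading `prompt` parameter (not shown in this hunk) and is importable as `generate.base.main`:

    from pathlib import Path

    from generate.base import main

    # A 4-bit GPTQ checkpoint served through the new Marlin kernel; "gptq.int4"
    # and "marlin" are values from the Literal types above. kernel=None would
    # instead fall back to the kernel recorded at quantization time (see below).
    main(
        prompt="Hello, my name is",
        quantize="gptq.int4",
        kernel="marlin",
        checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    )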
@@ -129,6 +144,8 @@ def main(
     """
     precision = precision or get_default_supported_precision(training=False)
 
+    # --------------------- Precision Plugin and Quantization flags ---------------------
+
     plugins = None
     use_gptq = False
     if quantize is not None and quantize.startswith("bnb."):
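
The `bnb.` branch that opens here continues at the top of the next hunk: the suffix after `bnb.` plus a compute dtype select Fabric's bitsandbytes plugin, which then takes the place of the `precision` flag. The same wiring in isolation (a sketch using Lightning's public `BitsandbytesPrecision`; `"nf4"` stands in for `quantize[4:]`):

    import torch
    import lightning as L
    from lightning.fabric.plugins import BitsandbytesPrecision

    # quantize="bnb.nf4" under 16-true precision resolves to:
    plugins = BitsandbytesPrecision("nf4", torch.float16)
    fabric = L.Fabric(devices=1, precision=None, plugins=plugins)  # plugin replaces precision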
@@ -137,67 +154,70 @@ def main(
         dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
         plugins = BitsandbytesPrecision(quantize[4:], dtype)
         precision = None
-    elif quantize is not None and quantize == "gptq":
+    elif quantize is not None and quantize.startswith("gptq"):
         use_gptq = True
+        bits = quantize[-1]
+        if precision != "16-true":
+            print(
+                f"AutoGPTQ requires float16 precision, but {precision} was selected. Overriding precision to float16."
+            )
+            precision = "16-true"
 
     fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
 
-    check_valid_checkpoint_dir(checkpoint_dir)
+    # --------------------- Config ---------------------
+
+    check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_json(checkpoint_dir / "lit_config.json")
 
     if use_gptq:
-        from auto_gptq.modeling._base import BaseQuantizeConfig
-
-        if not (quantize_config_path := checkpoint_dir / "autogptq_config.json").is_file():
-            raise ValueError("AutoGPTQ config is missing.")
-        autogptq_config = json.loads(quantize_config_path.read_text())
-        kernel_from_config = autogptq_config.pop("kernel")
-        kernel = kernel or kernel_from_config
-        quantize_config = BaseQuantizeConfig(**autogptq_config)
-        model_file = f"lit_model_gptq.{quantize_config.bits}bit.pth"
-        if not (checkpoint_dir / model_file).is_file():
+        from quantize.autogptq import QuantizeConfig
+
+        quantized_model_dir = checkpoint_dir / f"quantized/{bits}bit"
+        quantize_config = QuantizeConfig.load_config(quantized_model_dir / "quantize_config.json")
+        kernel = kernel or quantize_config.kernel
+        if kernel == "marlin" and quantize_config.marlin_cached:
+            model_file = "marlin_cache.pth"
+        else:
+            model_file = "lit_model_gptq.pth"
+        if not (quantized_model_dir / model_file).is_file():
             raise ValueError(f"`{model_file}` is missing. Please run `python quantize/autogptq.py` first.")
+        checkpoint_path = quantized_model_dir / model_file
     else:
-        model_file = "lit_model.pth"
+        checkpoint_path = checkpoint_dir / "lit_model.pth"
 
-    checkpoint_path = checkpoint_dir / model_file
+    # --------------------- Tokenizer ---------------------
 
     tokenizer = Tokenizer(checkpoint_dir)
    encoded = tokenizer.encode(prompt, device=fabric.device)
     prompt_length = encoded.size(0)
     max_returned_tokens = prompt_length + max_new_tokens
 
+    # --------------------- Model ---------------------
+
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     t0 = time.perf_counter()
     with fabric.init_module(empty_init=True), torch.device("meta") if use_gptq else nullcontext():
         model = GPT(config)
 
-    if use_gptq:
-        from quantize.autogptq import AutoGPTQ
-
-        model.config.model_type = None  # used in .from_pretrained and .from_quantized
-        model.config.pad_token_id = None  # _prepare_examples_for_quantization
-        model.config.eos_token_id = tokenizer.eos_id  # _prepare_examples_for_quantization
-        model.config.use_cache = False  # for quantization it's disabled anyway
-        autogptq = AutoGPTQ(model=model, quantized=True, quantize_config=quantize_config)
-        autogptq.convert_model_to_quantized(kernel)
-
-    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
     with fabric.init_tensor():
         # set the max_seq_length to limit the memory usage to what we need
         model.max_seq_length = max_returned_tokens
         # enable the kv cache
         model.set_kv_cache(batch_size=1)
     model.eval()
 
-    if compile:
-        torch._dynamo.config.automatic_dynamic_shapes = True
-        torch._inductor.config.triton.unique_kernel_names = True
-        torch._inductor.config.coordinate_descent_tuning = True
-        global next_token
-        next_token = torch.compile(next_token, mode="reduce-overhead")
+    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
+
+    # --------------------- Conversion to AutoGPTQ ---------------------
+
+    if use_gptq:
+        from quantize.autogptq import AutoGPTQ
+
+        autogptq = AutoGPTQ(model=model, quantized=True, quantize_config=quantize_config)
+        autogptq.convert_to_quantized(kernel, device=fabric.device)
+
+    # --------------------- Load State Dict ---------------------
 
     if use_gptq:
         state_dict = torch.load(str(checkpoint_path), mmap=True, map_location=fabric.device)
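
Spelled out, the checkpoint resolution added above works like this (a sketch with assumed values; the directory layout follows the `quantized/{bits}bit` f-string and the file names come from the two branches):

    from pathlib import Path

    checkpoint_dir = Path("checkpoints/stabilityai/stablelm-base-alpha-3b")
    quantize = "gptq.int4"

    bits = quantize[-1]  # "4": works because int2/int3/int4/int8 all end in a single digit
    quantized_model_dir = checkpoint_dir / f"quantized/{bits}bit"
    # Expected contents: quantize_config.json plus either
    #   marlin_cache.pth   -- Marlin-repacked weights cached by an earlier run, or
    #   lit_model_gptq.pth -- plain GPTQ weights, repacked when kernel == "marlin"
    print(quantized_model_dir)  # checkpoints/stabilityai/stablelm-base-alpha-3b/quantized/4bit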
@@ -208,10 +228,26 @@ def main(
     model = fabric.setup_module(model)
     load_checkpoint(fabric, model, checkpoint_path)
 
+    # --------------------- Convert to Marlin ---------------------
+
+    if use_gptq:
+        if kernel == "marlin":
+            autogptq.convert_quantized_to_marlin(quantized_model_dir)
+
+        # post_init is executed only on a CUDA device
+        autogptq.post_init()
+
+    # --------------------- Final preparations ---------------------
+
+    if compile:
+        torch._dynamo.config.automatic_dynamic_shapes = True
+        torch._inductor.config.triton.unique_kernel_names = True
+        torch._inductor.config.coordinate_descent_tuning = True
+        global next_token
+        next_token = torch.compile(next_token, mode="reduce-overhead")
+
+    # --------------------- Generation ---------------------
+
     L.seed_everything(1234)
     for i in range(num_samples):
         t0 = time.perf_counter()
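
Relocating the `compile` block after quantization and weight loading means compilation sees the final (possibly Marlin-converted) modules. Its settings in isolation, as a runnable sketch where `step` stands in for the script's global `next_token`:

    import torch

    torch._dynamo.config.automatic_dynamic_shapes = True      # tolerate the growing sequence length
    torch._inductor.config.triton.unique_kernel_names = True  # name kernels for easier profiling
    torch._inductor.config.coordinate_descent_tuning = True   # extra autotuning of generated kernels

    def step(x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(x, dim=-1).argmax(dim=-1)

    # "reduce-overhead" uses CUDA graphs (where available) to cut per-token launch cost
    step = torch.compile(step, mode="reduce-overhead")
    print(step(torch.randn(1, 8)))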