-
Notifications
You must be signed in to change notification settings - Fork 797
/
lora.py
147 lines (125 loc) · 5.68 KB
/
lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import sys
import time
from pathlib import Path
from typing import Literal, Optional
import lightning as L
import torch
from lightning.fabric.plugins import BitsandbytesPrecision
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate.base import generate
from lit_gpt import Tokenizer
from lit_gpt.lora import GPT, Config, merge_lora_weights
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, gptq_quantization, lazy_load
from scripts.prepare_alpaca import generate_prompt
# LoRA settings forwarded to `Config.from_json` inside `main` (as `r`, `alpha` and `dropout`).
# NOTE(review): presumably these have to match the values used by `finetune/lora.py` when the
# adapter checkpoint was trained -- confirm before changing them.
lora_r: int = 8
lora_alpha: int = 16
lora_dropout: float = 0.05

# Per-layer switches forwarded to `Config.from_json` as `to_query`, `to_key`, `to_value`,
# `to_projection`, `to_mlp` and `to_head` respectively.
lora_query: bool = True
lora_key: bool = False
lora_value: bool = True
lora_projection: bool = False
lora_mlp: bool = False
lora_head: bool = False
def main(
    prompt: str = "What food do llamas eat?",
    input: str = "",
    lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]] = None,
    max_new_tokens: int = 100,
    top_k: Optional[int] = 200,
    temperature: float = 0.8,
    precision: Optional[str] = None,
) -> None:
    """Generates a response based on a given instruction and an optional input.

    This script will only work with checkpoints from the instruction-tuned GPT-LoRA model.
    See `finetune/lora.py`.

    The response is printed to stdout; progress, timing and memory information go to stderr.

    Args:
        prompt: The prompt/instruction (Alpaca style).
        input: Optional input (Alpaca style).
        lora_path: Path to the checkpoint with trained adapter weights, which are the output of
            `finetune/lora.py`.
        checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights.
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            - gptq.int4: 4-bit quantization from GPTQ
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        precision: Indicates the Fabric precision setting to use.

    Raises:
        ValueError: If a ``bnb.*`` quantization mode is combined with a mixed precision setting, or if
            ``gptq.int4`` is requested but the GPTQ-quantized checkpoint file does not exist.
        KeyError: If a ``bnb.*`` quantization mode is combined with a precision other than
            ``16-true``, ``bf16-true`` or ``32-true``.
    """
    precision = precision or get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        # `quantize[4:]` drops the "bnb." prefix before handing the mode to the plugin
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        # the plugin now carries the dtype, so Fabric is not given a precision string as well
        precision = None

    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)
    fabric.launch()

    check_valid_checkpoint_dir(checkpoint_dir)

    config = Config.from_json(
        checkpoint_dir / "lit_config.json",
        r=lora_r,
        alpha=lora_alpha,
        dropout=lora_dropout,
        to_query=lora_query,
        to_key=lora_key,
        to_value=lora_value,
        to_projection=lora_projection,
        to_mlp=lora_mlp,
        to_head=lora_head,
    )

    if quantize == "gptq.int4":
        model_file = "lit_model_gptq.4bit.pth"
        if not (checkpoint_dir / model_file).is_file():
            raise ValueError("Please run `python quantize/gptq.py` first")
    else:
        model_file = "lit_model.pth"
    checkpoint_path = checkpoint_dir / model_file

    tokenizer = Tokenizer(checkpoint_dir)
    sample = {"instruction": prompt, "input": input}
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, device=fabric.device)
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
        model = GPT(config)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    t0 = time.perf_counter()
    checkpoint = lazy_load(checkpoint_path)
    lora_checkpoint = lazy_load(lora_path)
    # the adapter file may hold the state dict directly or nested under a "model" key;
    # its entries are layered on top of the pretrained weights before loading
    checkpoint.update(lora_checkpoint.get("model", lora_checkpoint))
    model.load_state_dict(checkpoint)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

    merge_lora_weights(model)
    model = fabric.setup(model)

    # fixed seed so that repeated runs sample the same tokens
    L.seed_everything(1234)
    t0 = time.perf_counter()
    y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
    t = time.perf_counter() - t0

    output = tokenizer.decode(y)
    # `y` includes the prompt tokens, so keep only the text following the response marker.
    # If the marker cannot be found in the decoded text, print everything rather than
    # failing with an IndexError after the generation has already been paid for.
    segments = output.split("### Response:")
    if len(segments) > 1:
        output = segments[1]
    output = output.strip()
    fabric.print(output)

    tokens_generated = y.size(0) - prompt_length
    fabric.print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
if __name__ == "__main__":
    # Script entry point: jsonargparse is only needed when this file is run directly.
    import jsonargparse

    # Select the "high" float32 matmul precision mode before any model code runs
    # (see the PyTorch documentation for the exact speed/precision trade-off).
    torch.set_float32_matmul_precision("high")

    # Expose the parameters of `main` as command-line arguments and invoke it.
    jsonargparse.CLI(main)