13 changes: 11 additions & 2 deletions tico/quantization/evaluation/script/llm_tasks_eval.py

@@ -23,11 +23,20 @@


 def evaluate_llm_on_tasks(
-    model, tokenizer: AutoTokenizer, tasks: str
+    model: AutoModelForCausalLM,
+    tokenizer: AutoTokenizer,
+    tasks: str,
+    max_length: int | None = None,
 ) -> dict[str, Any]:
     if hasattr(model, "wrapped"):
         model = model.wrapped
-    model_to_evaluate = HFLM(model, "causal", tokenizer=tokenizer)
+    model_to_evaluate = HFLM(
+        model,
+        "causal",
+        tokenizer=tokenizer,
+        max_length=max_length,
+        truncation=True,
+    )
     tasks_list: list[str] = tasks.split(",")
     return evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list)

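For context, a hedged sketch of how the updated helper might be invoked. It mirrors the call sites changed later in this PR; the model/tokenizer variables, task names, and max_length value are illustrative, not taken from the repository.

# Illustrative only: assumes `model` (an AutoModelForCausalLM, or a quant wrapper
# exposing `.wrapped`) and `tokenizer` have already been loaded.
results = evaluate_llm_on_tasks(
    model,
    tokenizer,
    tasks="wikitext,hellaswag",  # comma-separated lm-eval task names (illustrative)
    max_length=1024,             # illustrative cap on the sequence length used by the harness
)
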
29 changes: 21 additions & 8 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

@@ -107,7 +107,7 @@ def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     print(f"saving the whole model to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
+            cm = tico.convert(q_m, (calib_inputs[0],), strict=False)

     cm.save(save_path)

@@ -238,14 +238,18 @@ def evaluate(q_m, tokenizer, dataset_test, args):
     # -------------------------------------------------------------------------
     print("\nCalculating perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
-    ppl_uint8 = perplexity(q_m, enc, args.device, stride=args.max_seq_len)
+    ppl_uint8 = perplexity(
+        q_m, enc, args.device, max_length=args.max_seq_len, stride=args.max_seq_len
+    )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
     print(f"│ int16 : {ppl_uint8:8.2f}")
     print("└───────────────────────────────────────────")

     if args.eval_tasks is not None:
-        results = evaluate_llm_on_tasks(q_m, tokenizer, args.eval_tasks)
+        results = evaluate_llm_on_tasks(
+            q_m, tokenizer, args.eval_tasks, max_length=args.max_seq_len
+        )
         print("Quantized RESULTS ARE:")
         print(make_table(results))
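
The max_length/stride pair passed here follows the usual sliding-window perplexity recipe. Below is a generic, hedged sketch of that recipe — it is not tico's perplexity helper (whose internals are not shown in this diff), just an approximation of what such a helper typically does with these two parameters.

import torch

def sliding_window_ppl(model, enc, device, max_length: int, stride: int) -> float:
    """Generic sliding-window perplexity sketch (approximate token accounting)."""
    input_ids = enc["input_ids"].to(device)
    seq_len = input_ids.size(1)
    nll_sum, n_tokens, prev_end = 0.0, 0, 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        trg_len = end - prev_end                  # tokens not scored by an earlier window
        ids = input_ids[:, begin:end]
        labels = ids.clone()
        labels[:, :-trg_len] = -100               # ignore tokens already scored
        with torch.no_grad():
            loss = model(ids, labels=labels).loss  # mean NLL over unmasked labels
        nll_sum += loss.item() * trg_len
        n_tokens += trg_len
        prev_end = end
        if end == seq_len:
            break
    return float(torch.exp(torch.tensor(nll_sum / n_tokens)))

With stride equal to max_length (as the script uses), the windows do not overlap and every token is scored exactly once.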

@@ -330,7 +334,13 @@ def main():
         "--max_seq_len",
         type=int,
         default=None,
-        help="constraint for max_position_embeddings",
+        help="seq_len to use in model evaluation and conversion to circle",
     )
+    parser.add_argument(
+        "--calibrate_seq_len",
+        type=int,
+        default=2048,
+        help="seq_len to use in quantized model calibration. More the better",
+    )
     parser.add_argument(
         "--embedding_weight_bits",
@@ -387,9 +397,9 @@
     )

     model.config.use_cache = False # TODO use args for it
-    if args.max_seq_len is not None:
+    if args.calibrate_seq_len is not None:
         model.config.max_position_embeddings = min(
-            model.config.max_position_embeddings, args.max_seq_len
+            model.config.max_position_embeddings, args.calibrate_seq_len
         )

     dataset_test = load_dataset(
@@ -399,15 +409,17 @@
     print("\nCalculating original perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
     ppl_fp32 = perplexity(
-        model, enc, device, stride=model.config.max_position_embeddings
+        model, enc, device, max_length=args.max_seq_len, stride=args.max_seq_len
     )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
     print(f"│ FP32 : {ppl_fp32:8.2f}")
     print("└───────────────────────────────────────────")

     if args.eval_tasks is not None:
-        results = evaluate_llm_on_tasks(model, tokenizer, args.eval_tasks)
+        results = evaluate_llm_on_tasks(
+            model, tokenizer, args.eval_tasks, max_length=args.max_seq_len
+        )
         print("Original RESULTS ARE:")
         print(make_table(results))

@@ -456,6 +468,7 @@ def main():
     evaluate(q_m, tokenizer, dataset_test, args)

     if args.save_circle_to_folder is not None:
+        calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len))
         save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)

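The new reshape splits the calibration sequences (gathered at calibrate_seq_len) into (1, max_seq_len) chunks before export. A toy illustration with made-up shapes, assuming calibrate_seq_len is an integer multiple of max_seq_len (consistent with the calibrate_seq_len >= max_seq_len assumption discussed in the review thread below):

import torch

# Toy shapes only: calibrate_seq_len=2048, max_seq_len=512, 8 calibration sequences.
calib_inputs = [torch.zeros(1, 2048, dtype=torch.long) for _ in range(8)]

# (8, 1, 2048) -> (32, 1, 512): each 2048-token sequence becomes four 512-token chunks.
chunks = list(torch.stack(calib_inputs).reshape(-1, 1, 512))
assert len(chunks) == 32 and chunks[0].shape == (1, 512)
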
4 changes: 2 additions & 2 deletions tico/quantization/wrapq/wrappers/llama/quant_model.py

@@ -209,8 +209,8 @@ def forward(
         position_embeddings = self.get_position_embeddings_for(hidden_states)
         cos, sin = position_embeddings
         position_embeddings = (
-            self._fq(cos, self.obs_cos),
-            self._fq(sin, self.obs_sin),
+            self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos),
+            self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin),
         )

         # decoder layers

Comment on lines +212 to +213

Contributor Author:
Note for reviewers — this change makes it possible to remove the input padding.

Contributor:
Just out of curiosity, is this necessary? The attention_mask already masks the padded positions.

Contributor Author:
@mhs4670go Sorry for the lack of detail. attention_mask is prepared for the current seq_len as

    return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

so even when causal_mask_template is larger than seq_len everything is fine (it is just an upper-triangular matrix filled with a constant). We can do the same here: prepare (rope_cos_template, rope_sin_template) for a larger seq_len and then extract what is needed for the current seq_len. It is assumed that calibrate_seq_len >= max_seq_len.

Contributor:
Ah, this is necessary for evaluation, where the number of tokens is not fixed to 2048 (the sequence length to be exported). If we feed max_seq_len tokens when we export the model, the slicing becomes a no-op.

Contributor Author:
Yep.
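
For illustration, a minimal sketch of the template-slicing idea discussed in this thread: precompute the RoPE cos/sin tables once for the (longer) calibration length and slice them to the current sequence length at forward time, mirroring how causal_mask_template is sliced. The names (rope_cos_template, position_embeddings_for) and the head_dim / rope-base values below are illustrative, not the wrapper's actual code.

import torch

calib_seq_len, head_dim = 2048, 64  # illustrative sizes
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(calib_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)               # (calib_seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)        # (calib_seq_len, head_dim)
rope_cos_template = emb.cos()[None, :, :]      # (1, calib_seq_len, head_dim)
rope_sin_template = emb.sin()[None, :, :]

def position_embeddings_for(hidden_states: torch.Tensor):
    # Works for any seq_len <= calib_seq_len, i.e. calibrate_seq_len >= max_seq_len.
    seq_len = hidden_states.size(1)
    return (
        rope_cos_template[:, :seq_len, :],
        rope_sin_template[:, :seq_len, :],
    )

# A 512-token batch just takes the first 512 rows of each template.
cos, sin = position_embeddings_for(torch.zeros(1, 512, 4096))
assert cos.shape == (1, 512, 64) and sin.shape == (1, 512, 64)

With this in place, a shorter evaluation batch simply uses a prefix of the templates, and exporting at exactly max_seq_len makes the slice a no-op, as noted in the thread.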
26 changes: 0 additions & 26 deletions (file path not shown in the captured diff)

@@ -26,19 +26,6 @@
 from tico.quantization.wrapq.wrappers.registry import try_register


-def fix_inputs(config, pad_token_id, input_ids):
-    pads = torch.full(
-        (
-            input_ids.shape[0],
-            config.max_position_embeddings - input_ids.shape[1],
-        ),
-        fill_value=pad_token_id,
-        device=input_ids.device,
-    )
-
-    return torch.cat((input_ids, pads), dim=1)
-
-
 @try_register("transformers.models.llama.modeling_llama.LlamaForCausalLM")
 class QuantLlamaForCausalLM(QuantModuleBase):
     def __init__(

@@ -90,16 +77,6 @@ def forward(
         logits_to_keep: int | torch.Tensor = 0,
         **kwargs,
     ) -> CausalLMOutputWithPast:
-        orig_len = input_ids.shape[-1]  # type: ignore[union-attr]
-        pad_id = (
-            self.config.pad_token_id
-            if getattr(self.config, "pad_token_id", None) is not None
-            else self.config.eos_token_id
-        )
-
-        input_ids = fix_inputs(self.config, pad_id, input_ids)
-        if labels is not None:
-            labels = fix_inputs(self.config, pad_id, labels)

         output_attentions = self.config.output_attentions
         output_hidden_states = self.config.output_hidden_states

@@ -128,9 +105,6 @@ def forward(
             else logits_to_keep
         )
         logits = self.lm_head(hidden_states[:, slice_indices, :])
-        logits = logits[..., :orig_len, :]
-        if labels is not None:
-            labels = labels[..., :orig_len]

         loss = None
         if labels is not None: