
[quantization] Add calibrate_seq_len#533

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:add_calib_seq_len_PR
Mar 4, 2026

Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Mar 2, 2026

This PR adds `calibrate_seq_len` to improve accuracy.

Please see the results below.

log of `python tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py --model "HuggingFaceTB/SmolLM2-135M-Instruct" --max_seq_len 256 --gptq_mse "--eval_tasks" "winogrande,arc_easy,arc_challenge,openbookqa" ...`

Namespace(model='HuggingFaceTB/SmolLM2-135M-Instruct', device='cuda', dtype='float32', seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=False, no_PTQ=False, save_circle_to_folder='/home/stanislav', cache_dir='/mnt/storage/transformers_cache', nsamples_for_qcalibration=128, linear_weight_bits=4, gptq_mse=True, max_seq_len=256, calibrate_seq_len=2048, embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks='winogrande,arc_easy,arc_challenge,openbookqa')
=== Config ===
Model            : HuggingFaceTB/SmolLM2-135M-Instruct
Device           : cuda
DType            : float32

Loading FP model …
`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:01<00:00, 170.72it/s, Materializing param=model.norm.weight]

Calculating original perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (304978 > 8192). Running this sequence through the model will result in indexing errors
PPL: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 1191/1192 [01:38<00:00, 12.15it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 :    30.32
└───────────────────────────────────────────
`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way.
Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:03<00:00, 141.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1172/1172 [00:12<00:00, 96.22it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2376/2376 [00:21<00:00, 108.32it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1267/1267 [00:00<00:00, 11541.23it/s]
Running loglikelihood requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18722/18722 [21:08<00:00, 14.76it/s]
Original RESULTS ARE:
|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2594|±  |0.0128|
|             |       |none  |     0|acc_norm|↑  |0.2773|±  |0.0131|
|arc_easy     |      1|none  |     0|acc     |↑  |0.5400|±  |0.0102|
|             |       |none  |     0|acc_norm|↑  |0.4882|±  |0.0103|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2240|±  |0.0187|
|             |       |none  |     0|acc_norm|↑  |0.3320|±  |0.0211|
|winogrande   |      1|none  |     0|acc     |↑  |0.5107|±  |0.0140|

Applying GPTQ …
Quantizing layers: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [02:50<00:00,  5.67s/layer]
Wrapping layers with PTQWrapper …                                                                                                                                                                                                                         
Calibrating PTQ obeservers…
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [02:24<00:00,  1.13s/it]

Calculating perplexities …
PPL: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 1191/1192 [10:56<00:00,  1.81it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :    52.08
└───────────────────────────────────────────
`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way.
Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:05<00:00, 87.04it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1172/1172 [00:12<00:00, 91.11it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2376/2376 [00:24<00:00, 96.53it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1267/1267 [00:00<00:00, 9694.34it/s]
Running loglikelihood requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18722/18722 [1:29:06<00:00,  3.50it/s]
Quantized RESULTS ARE:
|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2560|±  |0.0128|
|             |       |none  |     0|acc_norm|↑  |0.2790|±  |0.0131|
|arc_easy     |      1|none  |     0|acc     |↑  |0.4617|±  |0.0102|
|             |       |none  |     0|acc_norm|↑  |0.4470|±  |0.0102|
|openbookqa   |      1|none  |     0|acc     |↑  |0.1960|±  |0.0178|
|             |       |none  |     0|acc_norm|↑  |0.2980|±  |0.0205|
|winogrande   |      1|none  |     0|acc     |↑  |0.5328|±  |0.0140|

saving the whole model to /home/stanislav/model.q.circle

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>

@stamalakhov stamalakhov self-assigned this Mar 2, 2026
@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from 17f4335 to 7f5f371 Compare March 2, 2026 14:43
@stamalakhov stamalakhov changed the title [quantization] [DRAFT] Add calibrate_seq_len [quantization] [DRAFT] Add calibrate_seq_len Mar 2, 2026
@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from 7f5f371 to dc07e8a Compare March 2, 2026 16:06
@mhs4670go
Contributor

This PR adds calibrate_seq_len to get better results on accuracy.

Could you elaborate more on this change?

  1. What is this calibrate_seq_len for?
  2. Why does it produce better results?

@stamalakhov
Contributor Author

This PR adds calibrate_seq_len to get better results on accuracy.

Could you elaborate more on this change?

  1. What is this calibrate_seq_len for?

@mhs4670go
It turned out that the longer the seq_len used for calibration, the better the accuracy we get. This is important for GPTQ.

  2. Why does it produce better results?

IMHO, calibration with a bigger seq_len seems to produce more outliers (especially in the inner layers). The calibration set is more challenging, so the quantized model performs better afterwards.
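The hypothesis above (a longer calibration window is more likely to capture rare outliers, so the observed activation ranges can only stay the same or widen) can be sketched with a toy min/max observer. This is a minimal illustration, not the actual TICO observer; `MinMaxObserver` and the heavy-tailed activation model here are assumptions for the sketch.

```python
import random

class MinMaxObserver:
    """Toy stand-in for a PTQ activation observer: tracks a running min/max."""
    def __init__(self):
        self.lo, self.hi = float("inf"), float("-inf")

    def observe(self, xs):
        self.lo = min(self.lo, min(xs))
        self.hi = max(self.hi, max(xs))

    def range(self):
        return self.hi - self.lo

random.seed(0)
# Heavy-tailed activations: cubing Gaussians makes rare outliers much larger.
acts = [random.gauss(0.0, 1.0) ** 3 for _ in range(2048)]

short_obs, long_obs = MinMaxObserver(), MinMaxObserver()
short_obs.observe(acts[:256])   # calibrate with seq_len = 256
long_obs.observe(acts)          # calibrate with seq_len = 2048

# The longer window sees a superset of the tokens, so its range is never narrower.
print(long_obs.range() >= short_obs.range())  # True
```

Since the longer sequence is a superset of the shorter one here, the wider range is guaranteed; in practice the effect is statistical, but the direction is the same.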

Just an example with unsloth/Llama-3.2-3B-Instruct, seq_len == 256:

| Config ID | PPL | arc_easy (%) | arc_challenge (%) | winogrande (%) | openbookqa (%) |
|-----------|----:|-------------:|------------------:|---------------:|---------------:|
| FP32 | 19.05 | 75 | 44 | 67 | 29 |
| GPTQ_MSE_w4A16 | 22.9 | 72 | 38 | 64 | 26 |
| GPTQ_MSE_w4A16_#526_equalization | 22.8 | 72 | 41 | 66 | 27 |
| GPTQ_MSE_w4A16_#533_equalization | 22.8 | 72 | 41 | 67 | 28 |

So it seems this PR produces better results.

@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from dc07e8a to e8acbdc Compare March 3, 2026 05:12
@stamalakhov
Contributor Author

@mhs4670go
I'll provide more samples to show how calibrate_seq_len can improve performance.

@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch 2 times, most recently from c73de94 to 3c29be1 Compare March 3, 2026 11:10
Comment on lines +212 to +213
self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos),
self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin),

Contributor Author

Note for reviewers. It makes it possible to remove padding.

Contributor

Just out of curiosity, is it necessary? Because attention_mask masks the padding positions.

Contributor Author

@mhs4670go
Sorry for the lack of details.
attention_mask is prepared for the current seq_len as

return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

so when the causal_mask_template is larger than seq_len, everything is fine (because it's just an upper-triangular matrix filled with a constant).
We can do the same here:
prepare (rope_cos_template, rope_sin_template) for a larger seq_len and then extract what is needed for the current seq_len.
It is assumed that calibrate_seq_len >= max_seq_len.
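The template-and-slice pattern described above can be sketched as follows. This is a plain-Python illustration with hypothetical names (`CAL_LEN`, `rope_tables`, etc. are not the actual TICO identifiers): the RoPE cos/sin tables are built once at the calibration length, and each forward pass slices out the prefix it needs, exactly like the causal mask template.

```python
import math

CAL_LEN, HEAD_DIM, BASE = 2048, 64, 10000.0

# Build RoPE cos/sin templates once, at the largest (calibration) length.
rope_cos_template = [
    [math.cos(pos / BASE ** (2 * i / HEAD_DIM)) for i in range(HEAD_DIM // 2)]
    for pos in range(CAL_LEN)
]
rope_sin_template = [
    [math.sin(pos / BASE ** (2 * i / HEAD_DIM)) for i in range(HEAD_DIM // 2)]
    for pos in range(CAL_LEN)
]

def rope_tables(seq_len):
    """Slice the precomputed templates for the current sequence length.
    Assumes calibrate_seq_len >= max_seq_len, so the slice never truncates."""
    assert seq_len <= CAL_LEN
    return rope_cos_template[:seq_len], rope_sin_template[:seq_len]

cos, sin = rope_tables(256)
print(len(cos), cos[0][0], sin[0][0])  # 256 1.0 0.0
```

When seq_len equals CAL_LEN, the slice is a no-op, which matches the export case discussed below in the thread.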

Contributor

Ah, this is necessary for evaluation, where token counts are not fixed to 2048 (the sequence length to be exported). If we feed max_seq_len tokens when we export the model, the slicing will be a no-op.

Contributor Author

Yep.

Comment thread tico/utils/convert.py Outdated
Comment on lines +330 to +331
if hasattr(mod, "wrapped"):
mod = mod.wrapped

Contributor Author

Note for reviewers: this lets us avoid calls like tico.convert(model.wrapped, ...

Contributor

This change is not necessary, because PTQWrapper just runs the wrapped module's forward. Did you get an error without this change?

Contributor Author

Ahhh. Yep. I got an error when the "--no_PTQ" option was set.

tico.convert(model.wrapped, ...

was the cause of the crash.
So I moved the check into convert, but didn't test whether it was actually necessary.
I'll check it.
Thank you!

Contributor Author

You are right. No need for this change. I'll remove it.

@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from 3c29be1 to 2dc6a98 Compare March 3, 2026 11:32
@stamalakhov
Contributor Author

stamalakhov commented Mar 3, 2026

Comparison of main branch accuracy with this PR.

For unsloth/Llama-3.2-3B-Instruct, main branch and different calibrate_seq_lens (seq_len == 256):

| Config ID | PPL | arc_easy (%) | arc_challenge (%) | winogrande (%) | openbookqa (%) |
|-----------|----:|-------------:|------------------:|---------------:|---------------:|
| FP32 | 19.05 | 75 | 44 | 67 | 29 |
| GPTQ_MSE_w4A16_main | 22.9 | 71 | 37 | 64 | 25 |
| GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len | 22.8 | 72 | 41 | 67 | 28 |
| GPTQ_MSE_w4A16_#533_equalization_with_4096_calibrate_seq_len | 22.9 | 71 | 42 | 67 | 27 |
Llama-3.2-3B-Instruct results from logs

Quantized RESULTS of GPTQ_MSE_w4A16_main:

|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.3729|±  |0.0141|
|             |       |none  |     0|acc_norm|↑  |0.4002|±  |0.0143|
|arc_easy     |      1|none  |     0|acc     |↑  |0.7071|±  |0.0093|
|             |       |none  |     0|acc_norm|↑  |0.6540|±  |0.0098|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2540|±  |0.0195|
|             |       |none  |     0|acc_norm|↑  |0.3540|±  |0.0214|
|winogrande   |      1|none  |     0|acc     |↑  |0.6433|±  |0.0135|

Quantized RESULTS of GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.4078|±  |0.0144|
|             |       |none  |     0|acc_norm|↑  |0.4386|±  |0.0145|
|arc_easy     |      1|none  |     0|acc     |↑  |0.7226|±  |0.0092|
|             |       |none  |     0|acc_norm|↑  |0.6738|±  |0.0096|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2840|±  |0.0202|
|             |       |none  |     0|acc_norm|↑  |0.3920|±  |0.0219|
|winogrande   |      1|none  |     0|acc     |↑  |0.6661|±  |0.0133|

For unsloth/Llama-3.2-1B-Instruct, main branch and calibrate_seq_len==2048 (seq_len == 256):

| Config ID | PPL | arc_easy (%) | arc_challenge (%) | winogrande (%) | openbookqa (%) |
|-----------|----:|-------------:|------------------:|---------------:|---------------:|
| FP32 | 22.97 | 69 | 35 | 61 | 27 |
| GPTQ_MSE_w4A16_main | 33.50 | 62 | 31 | 56 | 20 |
| GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len | 32.71 | 64 | 33 | 54 | 23 |
Llama-3.2-1B-Instruct results from logs

Quantized RESULTS of GPTQ_MSE_w4A16_main:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.3055|±  |0.0135|
|             |       |none  |     0|acc_norm|↑  |0.3191|±  |0.0136|
|arc_easy     |      1|none  |     0|acc     |↑  |0.6178|±  |0.0100|
|             |       |none  |     0|acc_norm|↑  |0.5749|±  |0.0101|
|openbookqa   |      1|none  |     0|acc     |↑  |0.1960|±  |0.0178|
|             |       |none  |     0|acc_norm|↑  |0.3140|±  |0.0208|
|winogrande   |      1|none  |     0|acc     |↑  |0.5620|±  |0.0139|

Quantized RESULTS of GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.3285|±  |0.0137|
|             |       |none  |     0|acc_norm|↑  |0.3473|±  |0.0139|
|arc_easy     |      1|none  |     0|acc     |↑  |0.6431|±  |0.0098|
|             |       |none  |     0|acc_norm|↑  |0.6090|±  |0.0100|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2320|±  |0.0189|
|             |       |none  |     0|acc_norm|↑  |0.3480|±  |0.0213|
|winogrande   |      1|none  |     0|acc     |↑  |0.5446|±  |0.0140|

For TinyLlama/TinyLlama-1.1B-Chat-v1.0, main branch and calibrate_seq_len==2048 (seq_len == 256):

| Config ID | PPL | arc_easy (%) | arc_challenge (%) | winogrande (%) | openbookqa (%) |
|-----------|----:|-------------:|------------------:|---------------:|---------------:|
| FP32 | 12.97 | 62 | 31 | 60 | 25 |
| GPTQ_MSE_w4A16_main | 14.50 | 59 | 29 | 60 | 24 |
| GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len | 14.34 | 59 | 28 | 60 | 25 |
TinyLlama-1.1B-Chat-v1.0 results from logs

Quantized RESULTS of GPTQ_MSE_w4A16_main:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2858|±  |0.0132|
|             |       |none  |     0|acc_norm|↑  |0.3020|±  |0.0134|
|arc_easy     |      1|none  |     0|acc     |↑  |0.5863|±  |0.0101|
|             |       |none  |     0|acc_norm|↑  |0.5152|±  |0.0103|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2420|±  |0.0192|
|             |       |none  |     0|acc_norm|↑  |0.3600|±  |0.0215|
|winogrande   |      1|none  |     0|acc     |↑  |0.6006|±  |0.0138|

Quantized RESULTS of GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2833|±  |0.0132|
|             |       |none  |     0|acc_norm|↑  |0.3251|±  |0.0137|
|arc_easy     |      1|none  |     0|acc     |↑  |0.5880|±  |0.0101|
|             |       |none  |     0|acc_norm|↑  |0.5387|±  |0.0102|
|openbookqa   |      1|none  |     0|acc     |↑  |0.2500|±  |0.0194|
|             |       |none  |     0|acc_norm|↑  |0.3500|±  |0.0214|
|winogrande   |      1|none  |     0|acc     |↑  |0.5967|±  |0.0138|

However, for HuggingFaceTB/SmolLM2-135M-Instruct:

| Config ID | PPL | arc_easy (%) | arc_challenge (%) | winogrande (%) | openbookqa (%) |
|-----------|----:|-------------:|------------------:|---------------:|---------------:|
| FP32 | 30.32 | 54 | 26 | 51 | 22 |
| GPTQ_MSE_w4A16_main | 50.64 | 48 | 26 | 53 | 20 |
| GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len | 52.08 | 46 | 26 | 53 | 20 |
SmolLM2-135M-Instruct results from logs

Quantized RESULTS of GPTQ_MSE_w4A16_main:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2577|±  |0.0128|
|             |       |none  |     0|acc_norm|↑  |0.2696|±  |0.0130|
|arc_easy     |      1|none  |     0|acc     |↑  |0.4752|±  |0.0102|
|             |       |none  |     0|acc_norm|↑  |0.4474|±  |0.0102|
|openbookqa   |      1|none  |     0|acc     |↑  |0.1960|±  |0.0178|
|             |       |none  |     0|acc_norm|↑  |0.3000|±  |0.0205|
|winogrande   |      1|none  |     0|acc     |↑  |0.5280|±  |0.0140|

Quantized RESULTS of GPTQ_MSE_w4A16_#533_equalization_with_2048_calibrate_seq_len:


|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.2560|±  |0.0128|
|             |       |none  |     0|acc_norm|↑  |0.2790|±  |0.0131|
|arc_easy     |      1|none  |     0|acc     |↑  |0.4617|±  |0.0102|
|             |       |none  |     0|acc_norm|↑  |0.4470|±  |0.0102|
|openbookqa   |      1|none  |     0|acc     |↑  |0.1960|±  |0.0178|
|             |       |none  |     0|acc_norm|↑  |0.2980|±  |0.0205|
|winogrande   |      1|none  |     0|acc     |↑  |0.5328|±  |0.0140|

So it seems that calibrate_seq_len can improve accuracy for most models.

At the very least, it provides the possibility to improve accuracy for some models.

)
parser.add_argument(
"--max_seq_len",
"--convert_seq_len",

Contributor Author

Or we can leave it as is (max_seq_len).

Contributor

Hmm.. Sorry for bothering you. I think max_seq_len is easier to understand.

Contributor Author

Ok.

@stamalakhov stamalakhov changed the title [quantization] [DRAFT] Add calibrate_seq_len [quantization] Add calibrate_seq_len Mar 4, 2026
@stamalakhov stamalakhov marked this pull request as ready for review March 4, 2026 06:06
@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch 2 times, most recently from 81f7be9 to 9799f11 Compare March 4, 2026 07:02
@stamalakhov stamalakhov requested a review from mhs4670go March 4, 2026 07:04
config = q_m.config

orig_seq_len = config.max_position_embeddings
config.max_position_embeddings = args.max_seq_len

Contributor

This kind of change can be removed.

First, I think we should fix the evaluate_llm_on_tasks API.

def evaluate_llm_on_tasks(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, tasks: str
) -> dict[str, Any]:
    model_to_evaluate = HFLM(model, "causal", tokenizer=tokenizer)
    tasks_list: list[str] = tasks.split(",")
    return evaluator.simple_evaluate(model_to_evaluate, tasks=tasks_list)

Since the accelerator has a fixed maximum sequence length, it is better to match the max sequence length during evaluation as well.

Therefore, the code will be:

def evaluate_llm_on_tasks(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, tasks: str, max_length: int
) -> dict[str, Any]:
  # ..
  model_to_evaluate = HFLM(
      model,
      "causal",
      tokenizer=tokenizer,
      max_length=max_length,
      truncation=True,
    )

Then, you don't have to change max_position_embeddings itself repeatedly.

Contributor Author

Ahh. Ok. I'll try.

Contributor Author

@mhs4670go
Done. At least the ppl is fine for SmolLM. I'll check eval_tasks (it takes around 1.5 hours).

Contributor Author

Everything is fine, same results.

@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from 9799f11 to 1d46dd2 Compare March 4, 2026 08:00
This PR adds `calibrate_seq_len` to get better accuracy and adjusts relevant code accordingly.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
@stamalakhov stamalakhov force-pushed the add_calib_seq_len_PR branch from 1d46dd2 to d8d7a88 Compare March 4, 2026 08:07
Contributor

@mhs4670go mhs4670go left a comment

LGTM

@mhs4670go mhs4670go merged commit f60e3d3 into Samsung:main Mar 4, 2026
7 checks passed
@stamalakhov stamalakhov deleted the add_calib_seq_len_PR branch March 4, 2026 08:31