Reproduce perplexity #49

Closed
deciding opened this issue Mar 11, 2024 · 6 comments

@deciding

In the README the PPL is:

Llama-2-7b | 1x16 | 5.92 | 2.4

In the paper it is:

Llama-2-7b AQLM 2.29 6.29 8.11

When I run locally using the same command as in the README,

CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py $MODEL_PATH $DATASET_PATH --nsamples=1024 \
 --num_codebooks=1 --nbits_per_codebook=16 --in_group_size=8 \
 --relative_mse_tolerance=0.01 --finetune_relative_mse_tolerance=0.001 \
 --finetune_batch_size=32 --local_batch_size=1 --offload_activations \
 --wandb --save $SAVE_PATH

it gives me

Llama-2-7b AQLM 2.29 6.45 8.39

May I ask why there is such a mismatch? Thanks for any clarification.

@Vahe1994
Owner

Vahe1994 commented Mar 12, 2024

Hi!
There are 2 different factors contributing to the mismatch.

  1. I believe the difference between your results and those in the paper is mainly due to differences in hyperparameters. The results reported in the paper were achieved using the following settings: --nsamples=1024 --num_codebooks=1 --nbits_per_codebook=16 --in_group_size=8 --relative_mse_tolerance=0.01 --finetune_lr=1e-5 --finetune_adam_beta1=0.90 --finetune_adam_beta2=0.95 --finetune_keep_best --finetune_relative_mse_tolerance=0.001 --finetune_batch_size=32 --local_batch_size=4 --save save_path --wandb (see the command sketch after this list). Additionally, results may vary slightly from run to run due to randomness. For more details, please refer to Table 8 in the paper's appendix.

  2. The result of 5.92 from the README/HF was achieved through full fine-tuning on top of the obtained quantization; please see Appendix A and Global finetuning? #30. The code for fine-tuning can be found in Added global funetuning & validation loss early stopping & gemma support #50.
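
For reference, the settings from point 1 assemble into a single invocation in the same style as the README command quoted above. This is only a sketch built from the flags listed there; $MODEL_PATH, $DATASET_PATH, $SAVE_PATH and the GPU list are placeholders to adapt to your own setup:

# Sketch: paper settings for Llama-2-7b 1x16 (flags copied from the list above).
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py $MODEL_PATH $DATASET_PATH \
    --nsamples=1024 \
    --num_codebooks=1 --nbits_per_codebook=16 --in_group_size=8 \
    --relative_mse_tolerance=0.01 \
    --finetune_lr=1e-5 --finetune_adam_beta1=0.90 --finetune_adam_beta2=0.95 \
    --finetune_keep_best --finetune_relative_mse_tolerance=0.001 \
    --finetune_batch_size=32 --local_batch_size=4 \
    --save $SAVE_PATH --wandb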

Hope this helps. If you have any additional questions, please feel free to ask.

@deciding
Author

Very clear, thanks so much. I will try to reproduce it.

@Godofnothing
Collaborator

@deciding The current Llama-2-7b checkpoint with wikitext2 ppl=5.91 was obtained as follows.

Quantization with blockwise fine-tuning yields 6.22 ppl. Compared to the version in the main branch, it adds early stopping on a validation set. The run (with main.py) used the following hyperparameters:

python main.py \
    $MODEL_PATH \
    $DATASET_PATH \
    --nsamples=2048 \
    --val_size=256 \
    --model_seqlen=4096 \
    --num_codebooks=1 \
    --nbits_per_codebook=16 \
    --in_group_size=8 \
    --out_group_size=1 \
    --relative_mse_tolerance=0.01 \
    --finetune_lr=1e-4 \
    --finetune_adam_beta1=0.90 \
    --finetune_adam_beta2=0.999 \
    --finetune_keep_best \
    --finetune_batch_size=8 \
    --finetune_max_epochs=20 \
    --finetune_early_stop=3 \
    --local_batch_size=4 \
    --offload_activations

The final model was obtained via end-to-end fine-tuning (script finetune.py) on top of the model above (i.e., $INPUT_PATH points to the quantized model produced by the blockwise run), with the following hyperparameters:

python finetune.py \
  --base_model $MODEL_PATH \
  --quant_model $INPUT_PATH \
  --dataset $DATASET_PATH \
  --nsamples=1024 \
  --val_size=256 \
  --lr=1e-5 \
  --adam_beta1=0.90 \
  --adam_beta2=0.999 \
  --epochs=5 \
  --early_stop=3 \
  --batch_size=8 \
  --microbatch_size=4 \
  \
  --temperature=1.0 \
  \
  --save $DATA_PATH \
  \
  --gradient_checkpointing

@deciding
Author

@Godofnothing I really appreciate the tuning details! Also, may I know the number of A100 GPU hours required for this fine-tuning script?

@Godofnothing
Collaborator

@deciding I do not remember the exact numbers, but I think the first part took about 1 day on 2 A100s and the second one about 6 hours on a single A100.

@deciding
Author

@Godofnothing Cool. Thanks a lot for the information 👍
