Added global finetuning & validation loss early stopping & gemma support #50

Merged 3 commits on Mar 19, 2024
README.md (46 changes: 44 additions & 2 deletions)

@@ -160,20 +160,62 @@ Main CLI arguments:
- `MODEL_PATH` - a path to either the Hugging Face hub (e.g. meta-llama/Llama-2-7b-hf) or a local folder with a transformers model and a tokenizer.
- `DATASET_PATH` - either a path to calibration data (see above) or a standard dataset `[c4, ptb, wikitext2]`
- for llama-2 models, you can use `DATASET_PATH=./data/red_pajama_n=1024_4096_context_length.pth` for a slice of RedPajama (up to 1024 samples)
- `--nsamples` - the number of calibration data _sequences_. If this parameter is not set, all available calibration data is used.
- `--nsamples` - the number of calibration data _sequences_ (train + validation). If this parameter is not set, all available calibration data is used.
- `--val_size` - the number of validation sequences used for early stopping during block finetuning. Defaults to 0; must be smaller than `--nsamples`.
- `--num_codebooks` - number of codebooks per layer
- `--nbits_per_codebook` - each codebook will contain 2 ** nbits_per_codebook vectors
- `--in_group_size` - how many weights are quantized together (aka "g" in the arXiv paper)
- `--finetune_batch_size` - (for fine-tuning only) the total number of sequences used for each optimization step
- `--local_batch_size` - when accumulating finetune_batch_size, process this many samples per GPU per forward pass (affects GPU RAM usage)
- `--relative_mse_tolerance` - (for initial calibration) stop training when (current_epoch_mse / previous_epoch_mse) > (1 - relative_mse_tolerance)
- `--finetune_relative_mse_tolerance` - same, but for fine-tuning
- `--finetune_max_epochs` - maximum number of passes through the calibration data during block tuning.
- `--finetune_early_stop` - maximum number of passes through the calibration data without improvement on the validation set.
- `--offload_activations` -- during calibration, move activations from GPU memory to RAM. This reduces VRAM usage while slowing calibration by ~10% (depending on your hardware).
- `--save` -- path to save/load quantized model. (see also: `--load`)
- `--wandb` - if this parameter is set, the code will log results to wandb
- `--attn_implementation` - specify the attention implementation (for transformers >= `4.38`). The sdpa attention sometimes causes issues, so the `eager` implementation is recommended.

There are additional hyperparameters available. Run `python main.py --help` for more details on command line arguments, including compression parameters. An example invocation is sketched below.
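
The sketch below strings together the arguments described above into one end-to-end quantization command. It is illustrative only: the positional order of the model and dataset paths, the hyperparameter values, and the save path are assumptions, so verify them against `python main.py --help` before running. Note that with `--num_codebooks=1`, `--nbits_per_codebook=16` and `--in_group_size=8`, every group of 8 weights is encoded by a single 16-bit code, i.e. roughly 2 bits per weight before codebook and scale overhead.

```bash
export MODEL_PATH=meta-llama/Llama-2-7b-hf
export DATASET_PATH=./data/red_pajama_n=1024_4096_context_length.pth

python main.py $MODEL_PATH $DATASET_PATH \
    --nsamples=1024 \
    --val_size=128 \
    --num_codebooks=1 \
    --nbits_per_codebook=16 \
    --in_group_size=8 \
    --relative_mse_tolerance=0.01 \
    --finetune_relative_mse_tolerance=0.001 \
    --finetune_max_epochs=25 \
    --finetune_early_stop=3 \
    --finetune_batch_size=32 \
    --local_batch_size=1 \
    --offload_activations \
    --attn_implementation=eager \
    --save ./quantized_model \
    --wandb
```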

### Finetuning

The accuracy of the quantized model can be further improved via block finetuning. First, the logits
of the float16/bfloat16 model are cached in RAM. Then, the differentiable parameters of the quantized model
are optimized to minimize KL-divergence with the teacher logits. Typically, we use the same calibration data that was used for model quantization.

The command to launch the script should look like this:

```bash
python finetune.py \
--base_model $MODEL_PATH \
--quant_model $INPUT_PATH \
--dataset $DATASET_PATH \
--nsamples=<TOTAL_SIZE> \
--val_size=<VAL_SIZE> \
--lr=1e-5 \
--adam_beta1=0.90 \
--adam_beta2=0.999 \
--epochs=5 \
--early_stop=3 \
--batch_size=8 \
--microbatch_size=4 \
\
--temperature=1.0 \
\
--save $DATA_PATH \
\
--gradient_checkpointing
```

Main CLI arguments:
- `--base_model` - path or name of the original floating-point model
- `--quant_model` - path to quantized model weights.
- `--dataset` - path or name of the calibration dataset
- `--nsamples` - the number of calibration data _sequences_ (train + validation). If this parameter is not set, all available calibration data is used.
- `--val_size` - the number of validation sequences used for early stopping during block finetuning. Defaults to 0; must be smaller than `--nsamples`.

**Note**: larger models require multi-GPU training. At the moment, FSDP training is not implemented; the model is finetuned in a single process with its parameters sharded across the available devices.

### Zero-shot benchmarks via LM Evaluation Harness

To perform zero-shot evaluation, we use the [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) framework with slight modifications. This repository contains a copy of the LM Evaluation Harness repo from early 2023 in the `lm-eval-harness` folder.
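
The command below is only a rough, hedged sketch of how the vanilla harness of that era is typically launched; the `hf-causal` backend and the flag names (`--model`, `--model_args`, `--tasks`, `--num_fewshot`, `--batch_size`) come from the upstream project, not from this repository's modified copy, which may expose additional options for loading quantized checkpoints. Treat it as an assumption and check the scripts in `lm-eval-harness` for the exact invocation.

```bash
cd lm-eval-harness

# Upstream-style zero-shot evaluation of a checkpoint on a few standard tasks.
python main.py \
    --model hf-causal \
    --model_args pretrained=$MODEL_PATH \
    --tasks winogrande,piqa,hellaswag,arc_easy,arc_challenge \
    --num_fewshot 0 \
    --batch_size 1
```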

convert_to_hf.py (26 changes: 23 additions & 3 deletions)

@@ -5,7 +5,12 @@

import torch
from tqdm.auto import trange
from transformers import AutoConfig, PretrainedConfig
from transformers import AutoConfig, AutoModelForCausalLM

try:
import safetensors
except ModuleNotFoundError:
safetensors = None


def get_int_dtype(nbits: int) -> torch.dtype:
@@ -28,15 +33,15 @@ def pack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:

def get_num_layers(config) -> int:
match config.model_type:
case "llama" | "mistral" | "mixtral":
case "llama" | "mistral" | "mixtral" | "gemma":
return config.num_hidden_layers
case unknown_type:
raise NotImplementedError(f"Can't get number of layers for {unknown_type}")


def get_layers_prefix(config) -> str:
match config.model_type:
case "llama" | "mistral" | "mixtral":
case "llama" | "mistral" | "mixtral" | "gemma":
return "model.layers"
case unknown_type:
raise NotImplementedError(f"Can't get layers prefix for {unknown_type}")
@@ -66,6 +71,9 @@ def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> [dict,
state_dict[key] = value.half()
linear_weights_not_to_quantize.append(key)

if "lm_head.weight" not in linear_weights_not_to_quantize:
linear_weights_not_to_quantize.append("lm_head.weight")

return state_dict, linear_weights_not_to_quantize


@@ -119,6 +127,11 @@ def add_inference_code(model_type: str, save_path: os.PathLike):
    type=str,
    help="Path to save HF compatible checkpoint to",
)
parser.add_argument(
    "--save_safetensors",
    action="store_true",
    help="Whether to save in safetensors format",
)
args = parser.parse_args()

old_config = AutoConfig.from_pretrained(args.model)
@@ -132,3 +145,10 @@ def add_inference_code(model_type: str, save_path: os.PathLike):
new_config_dict = update_config(old_config.to_diff_dict(), metadata, linear_weights_not_to_quantize)
with open(os.path.join(args.out_path, "config.json"), "w") as config_file:
    json.dump(new_config_dict, config_file, indent=4)

# convert to safetensors: reload the just-written checkpoint, remove the directory,
# and re-save it so that the weights are stored in safetensors format
if args.save_safetensors:
    assert safetensors
    model = AutoModelForCausalLM.from_pretrained(args.out_path, trust_remote_code=True, torch_dtype=torch.float16)
    shutil.rmtree(args.out_path)
    model.save_pretrained(args.out_path)
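
For reference, a hedged usage sketch of the updated conversion script. The positional argument order (source model, quantized checkpoint, output directory) is inferred from `args.model`, `in_path` and `args.out_path` above and should be confirmed with `python convert_to_hf.py --help`; the paths are placeholders, and only `--save_safetensors` is taken directly from this diff.

```bash
# Convert a quantized checkpoint produced by main.py into a transformers-compatible
# folder, re-saving the weights in safetensors format.
python convert_to_hf.py \
    meta-llama/Llama-2-7b-hf \
    ./quantized_model \
    ./converted_model_hf \
    --save_safetensors
```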