
add starcoder2 #1349

Merged
4 commits merged on Mar 6, 2024
69 changes: 69 additions & 0 deletions examples/starcoder2/qlora.yml
@@ -0,0 +1,69 @@
base_model: bigcode/starcoder2-3b

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca


dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora

adapter: qlora
lora_model_dir:

sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:
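
For readers unfamiliar with these fields, here is a rough Python sketch of what the key settings in this config correspond to, written against transformers + peft. This is an illustration only, not how axolotl wires things up internally (axolotl also handles sample packing, the cosine scheduler, the 8-bit AdamW optimizer, and so on), and it assumes a transformers version that ships the starcoder2 architecture.

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# load_in_4bit: true  ->  quantize the base model to 4-bit for QLoRA
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder2-3b",                   # base_model
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,                # bf16: auto (on Ampere+ GPUs)
    attn_implementation="flash_attention_2",   # flash_attention: true
)

lora_config = LoraConfig(
    r=32,                          # lora_r
    lora_alpha=16,                 # lora_alpha
    lora_dropout=0.05,             # lora_dropout
    target_modules="all-linear",   # lora_target_linear: true
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA adapter weights train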
13 changes: 12 additions & 1 deletion src/axolotl/monkeypatch/multipack.py
@@ -6,7 +6,14 @@
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data

SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "mixtral",
    "qwen2",
    "falcon",
    "phi",
    "gemma",
    "starcoder2",
]


def patch_for_multipack(model_type):
@@ -32,3 +39,7 @@ def patch_for_multipack(model_type):
        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "starcoder2":
        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
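
For context, the patch above swaps the module-level _get_unpad_data used by flash attention for axolotl's version, so that unpadding respects the boundaries between individually packed samples (sample_packing: true in the example config). A minimal sketch of the call site, which this diff does not show:

from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)

model_type = "starcoder2"  # in axolotl this comes from the model's config.json
if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    # Replaces transformers.models.starcoder2.modeling_starcoder2._get_unpad_data
    # with axolotl's get_unpad_data before the model is loaded.
    patch_for_multipack(model_type)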
Comment on lines +42 to +45
Collaborator: 🔥