Configure PEFT from config #3571

Merged · 7 commits · Jul 15, 2023
7 changes: 6 additions & 1 deletion model/model_training/configs/config.yaml
@@ -2,6 +2,7 @@ defaults:
rng_seed: 0xa1221f97
learning_rate: 1e-5
gradient_checkpointing: false
int8_training: false
gradient_accumulation_steps: 32
per_device_train_batch_size: 2
per_device_eval_batch_size: 2
@@ -803,8 +804,12 @@ rope_scaling_test:
residual_dropout_lima: true
log_wandb: true
peft_model: true
peft_type: "lora"
peft_config:
peft_type: "lora"
r: 16
superhot: true
superhot_config:
type: linear
scale: 2
datasets:
- dolly15k
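
For reference, the keys under peft_config are merged with the defaults declared in peft_model() below, with the YAML values taking precedence (the superhot keys belong to the rope-scaling setup, not the PEFT config). A rough sketch of what the example above resolves to for LoRA; the expanded target_modules list is only a placeholder for whatever get_all_linear_layers(model) would return:

from peft import LoraConfig

# peft_config from config.yaml, after peft_type has been popped off
yaml_peft_config = {"r": 16}

# defaults declared in peft_model(); YAML keys win on conflict
defaults = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": "all",  # later expanded to every nn.Linear except lm_head
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
    "modules_to_save": ["wte", "lm_head"],
}

resolved = {**defaults, **yaml_peft_config}
resolved["target_modules"] = ["q_proj", "k_proj", "v_proj", "o_proj"]  # placeholder for get_all_linear_layers(model)
config = LoraConfig(**resolved)
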
57 changes: 37 additions & 20 deletions model/model_training/models/peft_modeling.py
@@ -3,7 +3,7 @@

import torch
from huggingface_hub import hf_hub_download
from model_training.utils.utils import get_model, get_tokenizer
from model_training.utils.utils import get_all_linear_layers, get_model, get_tokenizer, merge_dicts
from peft import LoraConfig, PeftModel, PrefixTuningConfig, get_peft_model, prepare_model_for_int8_training


@@ -18,11 +18,15 @@ def load_peft_model(model, peft_model_path, tokenizer):
torch_dtype=model.dtype,
)
model.eos_token_id = tokenizer.eos_token_id
extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
embed_weights = torch.load(extra_embeds, map_location=model.device)
model.base_model.model.model.embed_tokens.weight[len(tokenizer) - embed_weights.shape[0] :, :] = embed_weights.to(
model.base_model.model.model.embed_tokens.weight.dtype
)
try:
extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
embed_weights = torch.load(extra_embeds, map_location=model.device)
model.base_model.model.model.embed_tokens.weight[
len(tokenizer) - embed_weights.shape[0] :, :
] = embed_weights.to(model.base_model.model.model.embed_tokens.weight.dtype)
except Exception:
print("Warning:Extra embeddings not added. This is expected if adapter file contains WTE")

return model
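
For intuition about the slice above, a self-contained toy sketch: the last embed_weights.shape[0] rows of the (already resized) embedding matrix are overwritten with the extra embeddings shipped alongside the adapter.

import torch

vocab_size, hidden = 10, 4          # toy sizes; vocab_size stands in for len(tokenizer)
embed = torch.zeros(vocab_size, hidden)
extra = torch.ones(3, hidden)       # stands in for the tensor loaded from extra_embeddings.pt

# same indexing as above: fill the last len(extra) rows, i.e. the newly added tokens
embed[vocab_size - extra.shape[0]:, :] = extra.to(embed.dtype)
assert torch.equal(embed[-3:], extra)
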


@@ -42,27 +46,40 @@ def make_inputs_require_grad(module, input, output):
return model


def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpointing=False):
def peft_model(model, training_config):
peft_config = training_config.peft_config
peft_type = peft_config.pop("peft_type", "lora")
if peft_type == "lora":
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
default_args = {
"r": 16,
"lora_alpha": 32,
"target_modules": "all",
"lora_dropout": 0.05,
"bias": "none",
"task_type": "CAUSAL_LM",
"modules_to_save": ["wte", "lm_head"],
}
kwargs = merge_dicts(default_args, peft_config)
if kwargs.get("target_modules") == "all":
kwargs.update({"target_modules": get_all_linear_layers(model)})
config = LoraConfig(**kwargs)
elif peft_type == "prefix-tuning":
config = PrefixTuningConfig(
num_virtual_tokens=30, prefix_projection=True, encoder_hidden_size=1024, task_type="CAUSAL_LM"
)
default_args = {
"num_virtual_tokens": 30,
"prefix_projection": True,
"encoder_hidden_size": 1024,
"task_type": "CAUSAL_LM",
}
kwargs = merge_dicts(default_args, peft_config)
config = PrefixTuningConfig(**kwargs)
else:
raise ValueError("peft_method config is lora or prefix-tuning")
model = get_peft_model(model, config)
if int8_training:

if training_config.int8_training:
model = prepare_model_for_int8_training(model)

if gradient_checkpointing:
if training_config.gradient_checkpointing:
model = prepare_model_for_gradient_checkpointing(model)
model.print_trainable_parameters()
return model
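
A hedged usage sketch of the new call signature; the Namespace stands in for the parsed training config (normally built from config.yaml by the trainer), and the tiny test checkpoint is used purely for illustration:

from argparse import Namespace

from transformers import AutoModelForCausalLM

from model_training.models.peft_modeling import peft_model

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
training_conf = Namespace(
    peft_config={"peft_type": "lora", "r": 8},
    int8_training=False,
    gradient_checkpointing=False,
)
model = peft_model(model, training_conf)  # wraps the base model with LoRA adapters and prints trainable params
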
4 changes: 1 addition & 3 deletions model/model_training/trainer_sft.py
@@ -424,9 +424,7 @@ def main():

if training_conf.peft_model:
print("Using PEFT model")
model = peft_model(
model, peft_type=training_conf.peft_type, gradient_checkpointing=training_conf.gradient_checkpointing
)
model = peft_model(model, training_conf)

if training_conf.quantization:
import bitsandbytes # This is noisy, so delay importing until after argument parsing so it doesn't make --help noisy
21 changes: 21 additions & 0 deletions model/model_training/utils/utils.py
@@ -432,3 +432,24 @@ def process_output(output: str, method: str = "v2", bot_name: str = "Joi") -> str:
answer = output.split("\n\n{}:".format(bot_name))[-1]
answer = answer.split("</s>")[0].replace("<|endoftext|>", "").lstrip().split("\n\n{}:".format(bot_name))[0]
return answer


def merge_dicts(default: dict, config: dict):
"""
Fill missing keys in config with values from default; existing config values take precedence.
"""
for k, v in default.items():
if k not in config.keys():
config.update({k: v})

return config
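
A quick example of the override semantics: keys already present in config win, missing ones are filled in from default (note the passed-in config dict is modified in place and returned):

defaults = {"r": 16, "lora_alpha": 32, "bias": "none"}
overrides = {"r": 8}
print(merge_dicts(defaults, overrides))  # {'r': 8, 'lora_alpha': 32, 'bias': 'none'}
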


def get_all_linear_layers(model):
cls = torch.nn.Linear

modules = {name.split(".")[-1] for name, module in model.named_modules() if isinstance(module, cls)}
if "lm_head" in modules:
modules.remove("lm_head")

return list(modules)
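
And a small self-contained example of get_all_linear_layers on a toy module (the layer names are made up for illustration); lm_head is always excluded so LoRA is not applied to the output projection:

import torch

toy = torch.nn.ModuleDict(
    {
        "q_proj": torch.nn.Linear(8, 8),
        "v_proj": torch.nn.Linear(8, 8),
        "lm_head": torch.nn.Linear(8, 16),
    }
)
print(sorted(get_all_linear_layers(toy)))  # ['q_proj', 'v_proj']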