diff --git a/model/model_training/models/peft_modeling.py b/model/model_training/models/peft_modeling.py
index 2355f541ab..4c796287b0 100644
--- a/model/model_training/models/peft_modeling.py
+++ b/model/model_training/models/peft_modeling.py
@@ -57,7 +57,6 @@ def peft_model(model, training_config):
             "lora_dropout": 0.05,
             "bias": "none",
             "task_type": "CAUSAL_LM",
-            "modules_to_save": ["wte", "lm_head"],
         }
         kwargs = merge_dicts(default_args, peft_config)
         if kwargs.get("target_modules") == "all":
diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 4c4c820999..56797a8e7f 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -27,7 +27,7 @@
 from torch import nn
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
-from transformers import PreTrainedModel, Trainer, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import IterableDatasetShard
 from transformers.trainer_utils import seed_worker
 from transformers.training_args import OptimizerNames
@@ -327,7 +327,10 @@ def main():
 
     init_rng(training_conf)
 
-    tokenizer = get_tokenizer(training_conf)
+    if training_conf.peft_model:
+        tokenizer = AutoTokenizer.from_pretrained(training_conf.model_name)
+    else:
+        tokenizer = get_tokenizer(training_conf)
 
     if not training_conf.deepspeed or training_conf.local_rank == 0:
         tokenizer_sanity_check(tokenizer)
@@ -416,7 +419,13 @@ def main():
         sampler = None
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
-    model = get_model(training_conf, tokenizer)
+    if training_conf.peft_model:
+        logging.warning("PEFT model: make sure this is an adapted base model which has added special tokens!")
+        model = AutoModelForCausalLM.from_pretrained(
+            training_conf.model_name, torch_dtype=torch.bfloat16 if training_conf.dtype == "bf16" else torch.float16
+        )
+    else:
+        model = get_model(training_conf, tokenizer)
 
     superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
     if superhot:
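
Note (not part of the patch): with `modules_to_save` removed from the default LoRA arguments, the token embedding and LM head are no longer trained and saved unless a config opts back in explicitly. Below is a minimal sketch of that opt-in using the `peft` library's `LoraConfig` directly, as an assumption-laden illustration rather than the repo's `peft_model()` wrapper; the module names `wte`/`lm_head` are model-specific.

```python
# Sketch only: explicitly re-enable saving of the embedding and output head,
# now that they are no longer part of the default LoRA arguments.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    modules_to_save=["wte", "lm_head"],  # explicit opt-in; names depend on the base model
)
# lora_model = get_peft_model(base_model, lora_config)  # requires a loaded base model
```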