
Commit

correct load model sft
sanagno committed Feb 20, 2023
1 parent eadf3e5 commit 6e081a0
Showing 4 changed files with 25 additions and 4 deletions.
15 changes: 14 additions & 1 deletion model/model_training/configs/config.yaml
@@ -50,16 +50,28 @@ defaults:
  log_wandb: true
  samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within
  verbose: false
  output_dir: saved_model

oa_dataset_only:
  datasets:
    - oa_private:
        data_path: .cache
        split: sft
        val_split: 0.0
        fraction: 1
        file: 2023-02-10_oasst_prod.jsonl

pythia:
  learning_rate: 8e-6
  model_name: EleutherAI/pythia-70m-deduped
  weight_decay: 0.01
  max_length: 520
  warmup_steps: 1000
  gradient_checkpointing: false
  gradient_accumulation_steps: 9
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4
  output_dir: pythia_model

galactica-125m:
  learning_rate: 5e-5
  model_name: facebook/galactica-125m
@@ -103,3 +115,4 @@ debug:
  quantization: false
  log_wandb: false
  verbose: true
  num_train_epochs: 0.2
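
For reference, these named sections are meant to be layered on top of defaults when a run is configured. Below is a minimal sketch of that layering, assuming later sections simply override earlier ones; the file path and section names mirror the diff above, but the merge logic here is an assumption, not the project's actual argument-parsing code.

import yaml

# Load the config.yaml shown above; each top-level key is one section.
with open("model/model_training/configs/config.yaml") as f:
    sections = yaml.safe_load(f)

# Assumed composition: start from `defaults`, then overlay the chosen sections.
training_conf = {}
for name in ("defaults", "oa_dataset_only", "pythia"):
    training_conf.update(sections[name])

print(training_conf["model_name"])  # EleutherAI/pythia-70m-deduped
print(training_conf["output_dir"])  # pythia_model (overrides saved_model from defaults)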
2 changes: 1 addition & 1 deletion model/model_training/models/__init__.py
@@ -25,7 +25,7 @@ def freeze_top_n_layers(model, target_layers):
    return model


-def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", **kwargs):
+def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", quantization=False, **kwargs):
    # encoder-decoder support for Flan-T5 like models
    # for now, we can use an argument but in the future,
    # we can automate this
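
A short usage sketch of the extended signature. The import path and the assumption that the remaining **kwargs are forwarded to the underlying from_pretrained call are not shown in this diff and are only assumed here.

# Hypothetical direct call; passing quantization by keyword keeps it out of
# **kwargs, so it is not forwarded to the model loader by accident.
from model_training.models import get_specific_model  # import path assumed

model = get_specific_model(
    "EleutherAI/pythia-70m-deduped",
    seq2seqmodel=False,
    cache_dir=".cache",
    quantization=False,
)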
8 changes: 7 additions & 1 deletion model/model_training/trainer_sft.py
@@ -221,8 +221,14 @@ def argument_parsing(notebook=False, notebook_args=None):
if training_conf.fuse_gelu:
    model = fuse_gelu(model)

output_dir = (
    training_conf.output_dir
    if training_conf.output_dir
    else f"{training_conf.model_name}-{training_conf.log_dir}-finetuned"
)

args = TrainingArguments(
-   output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
+   output_dir=output_dir,
    num_train_epochs=training_conf.num_train_epochs,
    warmup_steps=training_conf.warmup_steps,
    learning_rate=float(training_conf.learning_rate),
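
A small worked example of the new fallback, using the pythia section from the config above. The log_dir value is assumed for illustration; it is configured elsewhere.

from types import SimpleNamespace

training_conf = SimpleNamespace(
    model_name="EleutherAI/pythia-70m-deduped",
    log_dir="pythia-log",  # assumed value for illustration
    output_dir="pythia_model",
)

output_dir = (
    training_conf.output_dir
    if training_conf.output_dir
    else f"{training_conf.model_name}-{training_conf.log_dir}-finetuned"
)
print(output_dir)  # "pythia_model"
# With output_dir unset (None or ""), it falls back to
# "EleutherAI/pythia-70m-deduped-pythia-log-finetuned".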
4 changes: 3 additions & 1 deletion model/model_training/utils.py
@@ -214,7 +214,9 @@ def get_metrics(conf, tokenizer):


def get_model(conf, tokenizer):
-   model = get_specific_model(conf.model_name, conf.cache_dir, conf.quantization, conf.seq2seqmodel)
+   model = get_specific_model(
+       conf.model_name, cache_dir=conf.cache_dir, quantization=conf.quantization, seq2seqmodel=conf.seq2seqmodel
+   )

    if len(tokenizer) != model.get_input_embeddings().num_embeddings:
        assert not conf.freeze_layer, "Cannot change the number of embeddings if the model is frozen."
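
The keyword form matters because the old positional call would bind the arguments to the wrong parameters of the signature introduced above. A minimal illustration follows; the stub only mimics the parameter order and is not the real loader.

def get_specific_model(model_name, seq2seqmodel=False, cache_dir=".cache", quantization=False, **kwargs):
    # Stub that just reports how the arguments were bound.
    return {"seq2seqmodel": seq2seqmodel, "cache_dir": cache_dir, "quantization": quantization}

# Old-style positional call: cache_dir lands in seq2seqmodel, quantization in cache_dir, ...
print(get_specific_model("EleutherAI/pythia-70m-deduped", ".cache", False, False))
# {'seq2seqmodel': '.cache', 'cache_dir': False, 'quantization': False}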
