diff --git a/aixplain/modules/finetune/hyperparameters.py b/aixplain/modules/finetune/hyperparameters.py
index 3a68a9d7..51dc9842 100644
--- a/aixplain/modules/finetune/hyperparameters.py
+++ b/aixplain/modules/finetune/hyperparameters.py
@@ -2,6 +2,17 @@
 from dataclasses_json import dataclass_json
 
 
+class SchedulerType:
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+
+
 @dataclass_json
 @dataclass
 class Hyperparameters(object):
@@ -9,9 +20,12 @@ class Hyperparameters(object):
     train_batch_size: int = 4
     eval_batch_size: int = 4
     learning_rate: float = 2e-5
-    warmup_steps: int = 500
     generation_max_length: int = 225
     tokenizer_batch_size: int = 256
     gradient_checkpointing: bool = False
     gradient_accumulation_steps: int = 1
     max_seq_length: int = 4096
+    warmup_ratio: float = 0.0
+    warmup_steps: int = 0
+    early_stopping_patience: int = 1
+    lr_scheduler_type: SchedulerType = SchedulerType.LINEAR
diff --git a/tests/functional/finetune/finetune_functional_test.py b/tests/functional/finetune/finetune_functional_test.py
index ef04e26b..8cdbc77c 100644
--- a/tests/functional/finetune/finetune_functional_test.py
+++ b/tests/functional/finetune/finetune_functional_test.py
@@ -55,7 +55,7 @@ def validate_prompt_input_map(request):
 
 
 def test_end2end_text_generation(run_input_map):
-    model = ModelFactory.list(query=run_input_map["model_name"], is_finetunable=True)["results"][0]
+    model = ModelFactory.get(run_input_map["model_id"])
     dataset_list = [DatasetFactory.list(query=run_input_map["dataset_name"])["results"][0]]
     train_percentage, dev_percentage = 100, 0
     if run_input_map["required_dev"]:
@@ -74,6 +74,7 @@ def test_end2end_text_generation(run_input_map):
     while status != "onboarded" and (end - start) < TIMEOUT:
         status = finetune_model.check_finetune_status()
         assert status != "failed"
+        time.sleep(5)
     end = time.time()
     assert finetune_model.check_finetune_status() == "onboarded"
     result = finetune_model.run(run_input_map["inference_data"])
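
A minimal usage sketch of the new fields added in this diff. `Hyperparameters` and `SchedulerType` come from the patched module; passing the resulting object into a finetune creation call is assumed and not shown here, and the specific values are illustrative only:

    from aixplain.modules.finetune.hyperparameters import Hyperparameters, SchedulerType

    # Warmup can now be expressed as a ratio of total training steps
    # (warmup_steps defaults to 0 after this change), and the LR schedule
    # and early-stopping patience are configurable.
    hp = Hyperparameters(
        train_batch_size=4,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        lr_scheduler_type=SchedulerType.COSINE,
        early_stopping_patience=2,
    )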