diff --git a/tests/functional/finetune/data/finetune_test_cost_estimation.json b/tests/functional/finetune/data/finetune_test_cost_estimation.json new file mode 100644 index 00000000..a12ccdfb --- /dev/null +++ b/tests/functional/finetune/data/finetune_test_cost_estimation.json @@ -0,0 +1,11 @@ +[ + {"model_name": "gpt2", "model_id": "64e615671567f848804985e1", "dataset_name": "Test text generation dataset"}, + {"model_name": "falcon 7b instruct", "model_id": "65519d57bf42e6037ab109d5", "dataset_name": "Test text generation dataset"}, + {"model_name": "bloomz 7b", "model_id": "6551ab17bf42e6037ab109e0", "dataset_name": "Test text generation dataset"}, + {"model_name": "MPT 7B", "model_id": "6551a72bbf42e6037ab109d9", "dataset_name": "Test text generation dataset"}, + {"model_name": "falcon 7b", "model_id": "6551bff9bf42e6037ab109e1", "dataset_name": "Test text generation dataset"}, + {"model_name": "mistral 7b", "model_id": "6551a9e7bf42e6037ab109de", "dataset_name": "Test text generation dataset"}, + {"model_name": "MPT 7B Storywriter", "model_id": "6551a870bf42e6037ab109db", "dataset_name": "Test text generation dataset"}, + {"model_name": "llama 2 7b", "model_id": "6543cb991f695e72028e9428", "dataset_name": "Test text generation dataset"}, + {"model_name": "Llama 2 7B Chat", "model_id": "65519ee7bf42e6037ab109d8", "dataset_name": "Test text generation dataset"} +] \ No newline at end of file diff --git a/tests/functional/finetune/data/finetune_test_end2end.json b/tests/functional/finetune/data/finetune_test_end2end.json index cec36422..9682efa2 100644 --- a/tests/functional/finetune/data/finetune_test_end2end.json +++ b/tests/functional/finetune/data/finetune_test_end2end.json @@ -1,11 +1,13 @@ [ { + "model_name": "gpt2", "model_id": "64e615671567f848804985e1", "dataset_name": "Test text generation dataset", "inference_data": "Hello!", "required_dev": true }, { + "model_name": "aiR", "model_id": "6499cc946eb5633de15d82a1", "dataset_name": "Test search dataset", "inference_data": "Hello!", diff --git a/tests/functional/finetune/finetune_functional_test.py b/tests/functional/finetune/finetune_functional_test.py index 8cdbc77c..90b8b7b2 100644 --- a/tests/functional/finetune/finetune_functional_test.py +++ b/tests/functional/finetune/finetune_functional_test.py @@ -31,6 +31,7 @@ TIMEOUT = 20000.0 RUN_FILE = "tests/functional/finetune/data/finetune_test_end2end.json" +ESTIMATE_COST_FILE = "tests/functional/finetune/data/finetune_test_cost_estimation.json" LIST_FILE = "tests/functional/finetune/data/finetune_test_list_data.json" PROMPT_FILE = "tests/functional/finetune/data/finetune_test_prompt_validator.json" @@ -44,6 +45,11 @@ def run_input_map(request): return request.param +@pytest.fixture(scope="module", params=read_data(ESTIMATE_COST_FILE)) +def estimate_cost_input_map(request): + return request.param + + @pytest.fixture(scope="module", params=read_data(LIST_FILE)) def list_input_map(request): return request.param @@ -82,6 +88,17 @@ def test_end2end_text_generation(run_input_map): finetune_model.delete() +def test_cost_estimation_text_generation(estimate_cost_input_map): + model = ModelFactory.get(estimate_cost_input_map["model_id"]) + dataset_list = [DatasetFactory.list(query=estimate_cost_input_map["dataset_name"])["results"][0]] + finetune = FinetuneFactory.create(str(uuid.uuid4()), dataset_list, model) + assert type(finetune.cost) is FinetuneCost + cost_map = finetune.cost.to_dict() + assert "trainingCost" in cost_map + assert "hostingCost" in cost_map + assert "inferenceCost" in cost_map + + def test_list_finetunable_models(list_input_map): model_list = ModelFactory.list( function=Function(list_input_map["function"]),