From b7fa1ec5171e437bcc52e0119665c8c3fcb34cb1 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Wed, 21 Feb 2024 12:20:48 -0300 Subject: [PATCH 1/4] Finetune status object --- aixplain/modules/__init__.py | 1 + aixplain/modules/finetune/status.py | 42 + aixplain/modules/model.py | 9 +- tests/unit/finetune_test.py | 9 +- .../finetune_status_response.json | 978 ++++++++++++++++++ 5 files changed, 1035 insertions(+), 4 deletions(-) create mode 100644 aixplain/modules/finetune/status.py create mode 100644 tests/unit/mock_responses/finetune_status_response.json diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index 0902eaf4..bb9e696b 100644 --- a/aixplain/modules/__init__.py +++ b/aixplain/modules/__init__.py @@ -29,5 +29,6 @@ from .model import Model from .pipeline import Pipeline from .finetune import Finetune, FinetuneCost +from .finetune.status import FinetuneStatus from .benchmark import Benchmark from .benchmark_job import BenchmarkJob diff --git a/aixplain/modules/finetune/status.py b/aixplain/modules/finetune/status.py new file mode 100644 index 00000000..01640872 --- /dev/null +++ b/aixplain/modules/finetune/status.py @@ -0,0 +1,42 @@ +__author__ = "thiagocastroferreira" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
+Date: February 21st 2024
+Description:
+    FinetuneStatus Class
+"""
+
+from dataclasses import dataclass
+from dataclasses_json import dataclass_json
+from enum import Enum
+from typing import Optional, Text
+
+class FinetuneState(Text, Enum):
+    ONBOARDING = "onboarding"
+    ONBOARDED = "onboarded"
+    FAILED = "failed"
+
+@dataclass_json
+@dataclass
+class FinetuneStatus(object):
+    status: FinetuneState = FinetuneState.ONBOARDING
+    epoch: Optional[int] = None
+    step: Optional[int] = None
+    learning_rate: Optional[float] = None
+    training_loss: Optional[float] = None
+    validation_loss: Optional[float] = None
diff --git a/aixplain/modules/model.py b/aixplain/modules/model.py
index 0804af29..81684ac0 100644
--- a/aixplain/modules/model.py
+++ b/aixplain/modules/model.py
@@ -29,6 +29,7 @@
 from aixplain.factories.file_factory import FileFactory
 from aixplain.enums import Function, Supplier
 from aixplain.modules.asset import Asset
+from aixplain.modules.finetune.status import FinetuneStatus, FinetuneState
 from aixplain.utils import config
 from urllib.parse import urljoin
 from aixplain.utils.file_utils import _request_with_retry
@@ -251,14 +252,18 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
             response["error"] = msg
         return response
 
-    def check_finetune_status(self):
+    def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: Optional[int] = None) -> FinetuneStatus:
         """Check the status of the FineTune model.
 
+        Args:
+            after_epoch (Optional[int], optional): status after a given epoch. Defaults to None.
+            after_step (Optional[int], optional): status after a given step. Defaults to None.
+
         Raises:
             Exception: If the 'TEAM_API_KEY' is not provided.
 
         Returns:
-            str: The status of the FineTune model.
+            FinetuneStatus: The status of the FineTune model.
""" headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} try: diff --git a/tests/unit/finetune_test.py b/tests/unit/finetune_test.py index 5696572b..a43b7b46 100644 --- a/tests/unit/finetune_test.py +++ b/tests/unit/finetune_test.py @@ -35,6 +35,7 @@ COST_ESTIMATION_FILE = "tests/unit/mock_responses/cost_estimation_response.json" FINETUNE_URL = f"{config.BACKEND_URL}/sdk/finetune" FINETUNE_FILE = "tests/unit/mock_responses/finetune_response.json" +FINETUNE_STATUS_FILE = "tests/unit/mock_responses/finetune_status_response.json" PERCENTAGE_EXCEPTION_FILE = "tests/unit/data/create_finetune_percentage_exception.json" MODEL_FILE = "tests/unit/mock_responses/model_response.json" MODEL_URL = f"{config.BACKEND_URL}/sdk/models" @@ -96,14 +97,18 @@ def test_start(): def test_check_finetuner_status(): - model_map = read_data(MODEL_FILE) + model_map = read_data(FINETUNE_STATUS_FILE) asset_id = "test_id" with requests_mock.Mocker() as mock: test_model = Model(asset_id, "") url = f"{MODEL_URL}/{asset_id}" mock.get(url, headers=FIXED_HEADER, json=model_map) status = test_model.check_finetune_status() - assert status == model_map["status"] + assert status.status.value == model_map["status"] + assert status.training_loss == 0.007 + assert status.epoch == 2.75 + assert status.step == 1500 + assert status.learning_rate == 8.088235294117648e-06 @pytest.mark.parametrize("is_finetunable", [True, False]) diff --git a/tests/unit/mock_responses/finetune_status_response.json b/tests/unit/mock_responses/finetune_status_response.json new file mode 100644 index 00000000..d87ddd91 --- /dev/null +++ b/tests/unit/mock_responses/finetune_status_response.json @@ -0,0 +1,978 @@ +{ + "status": "onboarding", + "trainer_state": { + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 2.7529249827942186, + "eval_steps": 200, + "global_step": 1500, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + 
"epoch": 0.02, + "learning_rate": 9.938725490196079e-05, + "loss": 0.1106, + "step": 10 + }, + { + "epoch": 0.04, + "learning_rate": 9.877450980392157e-05, + "loss": 0.0482, + "step": 20 + }, + { + "epoch": 0.06, + "learning_rate": 9.816176470588235e-05, + "loss": 0.0251, + "step": 30 + }, + { + "epoch": 0.07, + "learning_rate": 9.754901960784314e-05, + "loss": 0.0228, + "step": 40 + }, + { + "epoch": 0.09, + "learning_rate": 9.693627450980392e-05, + "loss": 0.0217, + "step": 50 + }, + { + "epoch": 0.11, + "learning_rate": 9.632352941176472e-05, + "loss": 0.0126, + "step": 60 + }, + { + "epoch": 0.13, + "learning_rate": 9.57107843137255e-05, + "loss": 0.0111, + "step": 70 + }, + { + "epoch": 0.15, + "learning_rate": 9.509803921568627e-05, + "loss": 0.0162, + "step": 80 + }, + { + "epoch": 0.17, + "learning_rate": 9.448529411764707e-05, + "loss": 0.0149, + "step": 90 + }, + { + "epoch": 0.18, + "learning_rate": 9.387254901960785e-05, + "loss": 0.011, + "step": 100 + }, + { + "epoch": 0.2, + "learning_rate": 9.325980392156863e-05, + "loss": 0.0165, + "step": 110 + }, + { + "epoch": 0.22, + "learning_rate": 9.264705882352942e-05, + "loss": 0.0139, + "step": 120 + }, + { + "epoch": 0.24, + "learning_rate": 9.20343137254902e-05, + "loss": 0.0108, + "step": 130 + }, + { + "epoch": 0.26, + "learning_rate": 9.142156862745098e-05, + "loss": 0.0115, + "step": 140 + }, + { + "epoch": 0.28, + "learning_rate": 9.080882352941177e-05, + "loss": 0.0122, + "step": 150 + }, + { + "epoch": 0.29, + "learning_rate": 9.019607843137255e-05, + "loss": 0.0106, + "step": 160 + }, + { + "epoch": 0.31, + "learning_rate": 8.958333333333335e-05, + "loss": 0.0119, + "step": 170 + }, + { + "epoch": 0.33, + "learning_rate": 8.897058823529412e-05, + "loss": 0.0107, + "step": 180 + }, + { + "epoch": 0.35, + "learning_rate": 8.83578431372549e-05, + "loss": 0.0108, + "step": 190 + }, + { + "epoch": 0.37, + "learning_rate": 8.774509803921568e-05, + "loss": 0.0101, + "step": 200 + }, + { + "epoch": 
0.37, + "eval_loss": 0.009765202179551125, + "eval_runtime": 656.7127, + "eval_samples_per_second": 1.325, + "eval_steps_per_second": 0.662, + "step": 200 + }, + { + "epoch": 0.39, + "learning_rate": 8.713235294117648e-05, + "loss": 0.0119, + "step": 210 + }, + { + "epoch": 0.4, + "learning_rate": 8.651960784313726e-05, + "loss": 0.0099, + "step": 220 + }, + { + "epoch": 0.42, + "learning_rate": 8.590686274509803e-05, + "loss": 0.0105, + "step": 230 + }, + { + "epoch": 0.44, + "learning_rate": 8.529411764705883e-05, + "loss": 0.011, + "step": 240 + }, + { + "epoch": 0.46, + "learning_rate": 8.468137254901961e-05, + "loss": 0.0104, + "step": 250 + }, + { + "epoch": 0.48, + "learning_rate": 8.40686274509804e-05, + "loss": 0.0094, + "step": 260 + }, + { + "epoch": 0.5, + "learning_rate": 8.345588235294118e-05, + "loss": 0.0108, + "step": 270 + }, + { + "epoch": 0.51, + "learning_rate": 8.284313725490198e-05, + "loss": 0.0081, + "step": 280 + }, + { + "epoch": 0.53, + "learning_rate": 8.223039215686275e-05, + "loss": 0.0103, + "step": 290 + }, + { + "epoch": 0.55, + "learning_rate": 8.161764705882353e-05, + "loss": 0.01, + "step": 300 + }, + { + "epoch": 0.57, + "learning_rate": 8.100490196078431e-05, + "loss": 0.0111, + "step": 310 + }, + { + "epoch": 0.59, + "learning_rate": 8.039215686274511e-05, + "loss": 0.0097, + "step": 320 + }, + { + "epoch": 0.61, + "learning_rate": 7.977941176470589e-05, + "loss": 0.0093, + "step": 330 + }, + { + "epoch": 0.62, + "learning_rate": 7.916666666666666e-05, + "loss": 0.0093, + "step": 340 + }, + { + "epoch": 0.64, + "learning_rate": 7.855392156862746e-05, + "loss": 0.0104, + "step": 350 + }, + { + "epoch": 0.66, + "learning_rate": 7.794117647058824e-05, + "loss": 0.0094, + "step": 360 + }, + { + "epoch": 0.68, + "learning_rate": 7.732843137254903e-05, + "loss": 0.0099, + "step": 370 + }, + { + "epoch": 0.7, + "learning_rate": 7.671568627450981e-05, + "loss": 0.0092, + "step": 380 + }, + { + "epoch": 0.72, + "learning_rate": 
7.610294117647059e-05, + "loss": 0.0082, + "step": 390 + }, + { + "epoch": 0.73, + "learning_rate": 7.549019607843137e-05, + "loss": 0.0098, + "step": 400 + }, + { + "epoch": 0.73, + "eval_loss": 0.008508323691785336, + "eval_runtime": 657.1772, + "eval_samples_per_second": 1.324, + "eval_steps_per_second": 0.662, + "step": 400 + }, + { + "epoch": 0.75, + "learning_rate": 7.487745098039216e-05, + "loss": 0.01, + "step": 410 + }, + { + "epoch": 0.77, + "learning_rate": 7.426470588235294e-05, + "loss": 0.011, + "step": 420 + }, + { + "epoch": 0.79, + "learning_rate": 7.365196078431374e-05, + "loss": 0.0081, + "step": 430 + }, + { + "epoch": 0.81, + "learning_rate": 7.303921568627451e-05, + "loss": 0.01, + "step": 440 + }, + { + "epoch": 0.83, + "learning_rate": 7.242647058823529e-05, + "loss": 0.0088, + "step": 450 + }, + { + "epoch": 0.84, + "learning_rate": 7.181372549019609e-05, + "loss": 0.0101, + "step": 460 + }, + { + "epoch": 0.86, + "learning_rate": 7.120098039215687e-05, + "loss": 0.0082, + "step": 470 + }, + { + "epoch": 0.88, + "learning_rate": 7.058823529411765e-05, + "loss": 0.0091, + "step": 480 + }, + { + "epoch": 0.9, + "learning_rate": 6.997549019607842e-05, + "loss": 0.0085, + "step": 490 + }, + { + "epoch": 0.92, + "learning_rate": 6.936274509803922e-05, + "loss": 0.0094, + "step": 500 + }, + { + "epoch": 0.94, + "learning_rate": 6.875e-05, + "loss": 0.0093, + "step": 510 + }, + { + "epoch": 0.95, + "learning_rate": 6.813725490196079e-05, + "loss": 0.0087, + "step": 520 + }, + { + "epoch": 0.97, + "learning_rate": 6.752450980392157e-05, + "loss": 0.0096, + "step": 530 + }, + { + "epoch": 0.99, + "learning_rate": 6.691176470588235e-05, + "loss": 0.0089, + "step": 540 + }, + { + "epoch": 1.01, + "learning_rate": 6.629901960784314e-05, + "loss": 0.0084, + "step": 550 + }, + { + "epoch": 1.03, + "learning_rate": 6.568627450980392e-05, + "loss": 0.0073, + "step": 560 + }, + { + "epoch": 1.05, + "learning_rate": 6.507352941176472e-05, + "loss": 0.0066, + 
"step": 570 + }, + { + "epoch": 1.06, + "learning_rate": 6.44607843137255e-05, + "loss": 0.0083, + "step": 580 + }, + { + "epoch": 1.08, + "learning_rate": 6.384803921568627e-05, + "loss": 0.008, + "step": 590 + }, + { + "epoch": 1.1, + "learning_rate": 6.323529411764705e-05, + "loss": 0.0083, + "step": 600 + }, + { + "epoch": 1.1, + "eval_loss": 0.007716518826782703, + "eval_runtime": 657.2554, + "eval_samples_per_second": 1.324, + "eval_steps_per_second": 0.662, + "step": 600 + }, + { + "epoch": 1.12, + "learning_rate": 6.262254901960785e-05, + "loss": 0.0085, + "step": 610 + }, + { + "epoch": 1.14, + "learning_rate": 6.200980392156863e-05, + "loss": 0.0069, + "step": 620 + }, + { + "epoch": 1.16, + "learning_rate": 6.139705882352942e-05, + "loss": 0.0082, + "step": 630 + }, + { + "epoch": 1.17, + "learning_rate": 6.078431372549019e-05, + "loss": 0.0075, + "step": 640 + }, + { + "epoch": 1.19, + "learning_rate": 6.017156862745098e-05, + "loss": 0.0068, + "step": 650 + }, + { + "epoch": 1.21, + "learning_rate": 5.9558823529411766e-05, + "loss": 0.007, + "step": 660 + }, + { + "epoch": 1.23, + "learning_rate": 5.8946078431372556e-05, + "loss": 0.0086, + "step": 670 + }, + { + "epoch": 1.25, + "learning_rate": 5.833333333333334e-05, + "loss": 0.0075, + "step": 680 + }, + { + "epoch": 1.27, + "learning_rate": 5.7720588235294116e-05, + "loss": 0.0071, + "step": 690 + }, + { + "epoch": 1.28, + "learning_rate": 5.71078431372549e-05, + "loss": 0.0083, + "step": 700 + }, + { + "epoch": 1.3, + "learning_rate": 5.649509803921569e-05, + "loss": 0.0069, + "step": 710 + }, + { + "epoch": 1.32, + "learning_rate": 5.588235294117647e-05, + "loss": 0.0091, + "step": 720 + }, + { + "epoch": 1.34, + "learning_rate": 5.526960784313726e-05, + "loss": 0.0067, + "step": 730 + }, + { + "epoch": 1.36, + "learning_rate": 5.465686274509804e-05, + "loss": 0.0067, + "step": 740 + }, + { + "epoch": 1.38, + "learning_rate": 5.404411764705882e-05, + "loss": 0.008, + "step": 750 + }, + { + 
"epoch": 1.39, + "learning_rate": 5.343137254901961e-05, + "loss": 0.0073, + "step": 760 + }, + { + "epoch": 1.41, + "learning_rate": 5.2818627450980395e-05, + "loss": 0.0076, + "step": 770 + }, + { + "epoch": 1.43, + "learning_rate": 5.2205882352941185e-05, + "loss": 0.0066, + "step": 780 + }, + { + "epoch": 1.45, + "learning_rate": 5.159313725490197e-05, + "loss": 0.0067, + "step": 790 + }, + { + "epoch": 1.47, + "learning_rate": 5.0980392156862745e-05, + "loss": 0.007, + "step": 800 + }, + { + "epoch": 1.47, + "eval_loss": 0.007389526814222336, + "eval_runtime": 656.7058, + "eval_samples_per_second": 1.325, + "eval_steps_per_second": 0.662, + "step": 800 + }, + { + "epoch": 1.49, + "learning_rate": 5.036764705882353e-05, + "loss": 0.0058, + "step": 810 + }, + { + "epoch": 1.5, + "learning_rate": 4.975490196078432e-05, + "loss": 0.0074, + "step": 820 + }, + { + "epoch": 1.52, + "learning_rate": 4.9142156862745095e-05, + "loss": 0.0072, + "step": 830 + }, + { + "epoch": 1.54, + "learning_rate": 4.8529411764705885e-05, + "loss": 0.0058, + "step": 840 + }, + { + "epoch": 1.56, + "learning_rate": 4.791666666666667e-05, + "loss": 0.0076, + "step": 850 + }, + { + "epoch": 1.58, + "learning_rate": 4.730392156862745e-05, + "loss": 0.0088, + "step": 860 + }, + { + "epoch": 1.6, + "learning_rate": 4.669117647058824e-05, + "loss": 0.0091, + "step": 870 + }, + { + "epoch": 1.62, + "learning_rate": 4.607843137254902e-05, + "loss": 0.0088, + "step": 880 + }, + { + "epoch": 1.63, + "learning_rate": 4.546568627450981e-05, + "loss": 0.008, + "step": 890 + }, + { + "epoch": 1.65, + "learning_rate": 4.485294117647059e-05, + "loss": 0.0094, + "step": 900 + }, + { + "epoch": 1.67, + "learning_rate": 4.4240196078431374e-05, + "loss": 0.009, + "step": 910 + }, + { + "epoch": 1.69, + "learning_rate": 4.362745098039216e-05, + "loss": 0.0063, + "step": 920 + }, + { + "epoch": 1.71, + "learning_rate": 4.301470588235295e-05, + "loss": 0.0078, + "step": 930 + }, + { + "epoch": 1.73, + 
"learning_rate": 4.2401960784313724e-05, + "loss": 0.0062, + "step": 940 + }, + { + "epoch": 1.74, + "learning_rate": 4.1789215686274514e-05, + "loss": 0.0066, + "step": 950 + }, + { + "epoch": 1.76, + "learning_rate": 4.11764705882353e-05, + "loss": 0.0068, + "step": 960 + }, + { + "epoch": 1.78, + "learning_rate": 4.056372549019608e-05, + "loss": 0.0063, + "step": 970 + }, + { + "epoch": 1.8, + "learning_rate": 3.9950980392156864e-05, + "loss": 0.0064, + "step": 980 + }, + { + "epoch": 1.82, + "learning_rate": 3.933823529411765e-05, + "loss": 0.0072, + "step": 990 + }, + { + "epoch": 1.84, + "learning_rate": 3.872549019607844e-05, + "loss": 0.0066, + "step": 1000 + }, + { + "epoch": 1.84, + "eval_loss": 0.007207777351140976, + "eval_runtime": 657.2316, + "eval_samples_per_second": 1.324, + "eval_steps_per_second": 0.662, + "step": 1000 + }, + { + "epoch": 1.85, + "learning_rate": 3.8112745098039213e-05, + "loss": 0.0063, + "step": 1010 + }, + { + "epoch": 1.87, + "learning_rate": 3.7500000000000003e-05, + "loss": 0.0079, + "step": 1020 + }, + { + "epoch": 1.89, + "learning_rate": 3.688725490196079e-05, + "loss": 0.0073, + "step": 1030 + }, + { + "epoch": 1.91, + "learning_rate": 3.627450980392157e-05, + "loss": 0.0058, + "step": 1040 + }, + { + "epoch": 1.93, + "learning_rate": 3.566176470588235e-05, + "loss": 0.0068, + "step": 1050 + }, + { + "epoch": 1.95, + "learning_rate": 3.5049019607843136e-05, + "loss": 0.0065, + "step": 1060 + }, + { + "epoch": 1.96, + "learning_rate": 3.443627450980392e-05, + "loss": 0.0059, + "step": 1070 + }, + { + "epoch": 1.98, + "learning_rate": 3.382352941176471e-05, + "loss": 0.0072, + "step": 1080 + }, + { + "epoch": 2.0, + "learning_rate": 3.321078431372549e-05, + "loss": 0.0077, + "step": 1090 + }, + { + "epoch": 2.02, + "learning_rate": 3.2598039215686276e-05, + "loss": 0.0058, + "step": 1100 + }, + { + "epoch": 2.04, + "learning_rate": 3.198529411764706e-05, + "loss": 0.0052, + "step": 1110 + }, + { + "epoch": 2.06, + 
"learning_rate": 3.137254901960784e-05, + "loss": 0.0068, + "step": 1120 + }, + { + "epoch": 2.07, + "learning_rate": 3.075980392156863e-05, + "loss": 0.0064, + "step": 1130 + }, + { + "epoch": 2.09, + "learning_rate": 3.0147058823529413e-05, + "loss": 0.0076, + "step": 1140 + }, + { + "epoch": 2.11, + "learning_rate": 2.95343137254902e-05, + "loss": 0.0071, + "step": 1150 + }, + { + "epoch": 2.13, + "learning_rate": 2.8921568627450986e-05, + "loss": 0.0065, + "step": 1160 + }, + { + "epoch": 2.15, + "learning_rate": 2.8308823529411766e-05, + "loss": 0.0076, + "step": 1170 + }, + { + "epoch": 2.17, + "learning_rate": 2.7696078431372552e-05, + "loss": 0.005, + "step": 1180 + }, + { + "epoch": 2.18, + "learning_rate": 2.7083333333333332e-05, + "loss": 0.0064, + "step": 1190 + }, + { + "epoch": 2.2, + "learning_rate": 2.647058823529412e-05, + "loss": 0.0069, + "step": 1200 + }, + { + "epoch": 2.2, + "eval_loss": 0.006944665219634771, + "eval_runtime": 657.2803, + "eval_samples_per_second": 1.324, + "eval_steps_per_second": 0.662, + "step": 1200 + }, + { + "epoch": 2.22, + "learning_rate": 2.5857843137254905e-05, + "loss": 0.0057, + "step": 1210 + }, + { + "epoch": 2.24, + "learning_rate": 2.5245098039215685e-05, + "loss": 0.0062, + "step": 1220 + }, + { + "epoch": 2.26, + "learning_rate": 2.4632352941176472e-05, + "loss": 0.0061, + "step": 1230 + }, + { + "epoch": 2.28, + "learning_rate": 2.401960784313726e-05, + "loss": 0.0048, + "step": 1240 + }, + { + "epoch": 2.29, + "learning_rate": 2.340686274509804e-05, + "loss": 0.0068, + "step": 1250 + }, + { + "epoch": 2.31, + "learning_rate": 2.2794117647058825e-05, + "loss": 0.0059, + "step": 1260 + }, + { + "epoch": 2.33, + "learning_rate": 2.2181372549019608e-05, + "loss": 0.005, + "step": 1270 + }, + { + "epoch": 2.35, + "learning_rate": 2.1568627450980395e-05, + "loss": 0.0058, + "step": 1280 + }, + { + "epoch": 2.37, + "learning_rate": 2.0955882352941178e-05, + "loss": 0.0055, + "step": 1290 + }, + { + "epoch": 2.39, 
+ "learning_rate": 2.034313725490196e-05, + "loss": 0.0067, + "step": 1300 + }, + { + "epoch": 2.4, + "learning_rate": 1.9730392156862744e-05, + "loss": 0.0054, + "step": 1310 + }, + { + "epoch": 2.42, + "learning_rate": 1.9117647058823528e-05, + "loss": 0.0067, + "step": 1320 + }, + { + "epoch": 2.44, + "learning_rate": 1.8504901960784314e-05, + "loss": 0.0061, + "step": 1330 + }, + { + "epoch": 2.46, + "learning_rate": 1.7892156862745098e-05, + "loss": 0.0061, + "step": 1340 + }, + { + "epoch": 2.48, + "learning_rate": 1.7279411764705884e-05, + "loss": 0.0058, + "step": 1350 + }, + { + "epoch": 2.5, + "learning_rate": 1.6666666666666667e-05, + "loss": 0.0077, + "step": 1360 + }, + { + "epoch": 2.51, + "learning_rate": 1.6053921568627454e-05, + "loss": 0.0065, + "step": 1370 + }, + { + "epoch": 2.53, + "learning_rate": 1.5441176470588237e-05, + "loss": 0.0055, + "step": 1380 + }, + { + "epoch": 2.55, + "learning_rate": 1.482843137254902e-05, + "loss": 0.0058, + "step": 1390 + }, + { + "epoch": 2.57, + "learning_rate": 1.4215686274509804e-05, + "loss": 0.0057, + "step": 1400 + }, + { + "epoch": 2.57, + "eval_loss": 0.006838853470981121, + "eval_runtime": 656.7781, + "eval_samples_per_second": 1.325, + "eval_steps_per_second": 0.662, + "step": 1400 + }, + { + "epoch": 2.59, + "learning_rate": 1.3602941176470587e-05, + "loss": 0.0058, + "step": 1410 + }, + { + "epoch": 2.61, + "learning_rate": 1.2990196078431374e-05, + "loss": 0.0072, + "step": 1420 + }, + { + "epoch": 2.62, + "learning_rate": 1.2377450980392159e-05, + "loss": 0.0052, + "step": 1430 + }, + { + "epoch": 2.64, + "learning_rate": 1.1764705882352942e-05, + "loss": 0.0051, + "step": 1440 + }, + { + "epoch": 2.66, + "learning_rate": 1.1151960784313727e-05, + "loss": 0.0063, + "step": 1450 + }, + { + "epoch": 2.68, + "learning_rate": 1.053921568627451e-05, + "loss": 0.0062, + "step": 1460 + }, + { + "epoch": 2.7, + "learning_rate": 9.926470588235293e-06, + "loss": 0.005, + "step": 1470 + }, + { + "epoch": 
2.72, + "learning_rate": 9.31372549019608e-06, + "loss": 0.0059, + "step": 1480 + }, + { + "epoch": 2.73, + "learning_rate": 8.700980392156863e-06, + "loss": 0.0055, + "step": 1490 + }, + { + "epoch": 2.75, + "learning_rate": 8.088235294117648e-06, + "loss": 0.007, + "step": 1500 + } + ], + "logging_steps": 10, + "max_steps": 1632, + "num_train_epochs": 3, + "save_steps": 500, + "total_flos": 3.492274811692892e+18, + "trial_name": null, + "trial_params": null + } +} \ No newline at end of file From c26533527c604e08758f9d10f328597ead7b5856 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 12 Mar 2024 15:18:40 -0300 Subject: [PATCH 2/4] New finetune status checker --- aixplain/enums/asset_status.py | 43 + aixplain/modules/finetune/status.py | 12 +- aixplain/modules/metric.py | 3 +- aixplain/modules/model.py | 41 +- tests/unit/finetune_test.py | 25 +- .../finetune_status_response.json | 1015 +---------------- 6 files changed, 138 insertions(+), 1001 deletions(-) create mode 100644 aixplain/enums/asset_status.py diff --git a/aixplain/enums/asset_status.py b/aixplain/enums/asset_status.py new file mode 100644 index 00000000..134af26e --- /dev/null +++ b/aixplain/enums/asset_status.py @@ -0,0 +1,43 @@ +__author__ = "thiagocastroferreira" + +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: February 21st 2024 +Description: + Asset Enum +""" + +from enum import Enum +from typing import Text + +class AssetStatus(Text, Enum): + HIDDEN = 'hidden' + SCHEDULED = 'scheduled' + ONBOARDING = 'onboarding' + ONBOARDED = 'onboarded' + PENDING = 'pending' + FAILED = 'failed' + TRAINING = 'training' + REJECTED = 'rejected' + ENABLING = 'enabling' + DELETING = 'deleting' + DISABLED = 'disabled' + DELETED = 'deleted' + IN_PROGRESS = 'in_progress' + COMPLETED = 'completed' + CANCELING = 'canceling' + CANCELED = 'canceled' \ No newline at end of file diff --git a/aixplain/modules/finetune/status.py b/aixplain/modules/finetune/status.py index 01640872..89c0cd16 100644 --- a/aixplain/modules/finetune/status.py +++ b/aixplain/modules/finetune/status.py @@ -21,21 +21,17 @@ FinetuneCost Class """ +from aixplain.enums.asset_status import AssetStatus from dataclasses import dataclass from dataclasses_json import dataclass_json -from enum import Enum from typing import Optional, Text -class FinetuneState(Text, Enum): - ONBOARDING = "onboarding" - ONBOARDED = "onboarded" - FAILED = "failed" - @dataclass_json @dataclass class FinetuneStatus(object): - status: FinetuneState = FinetuneState.ONBOARDING - epoch: Optional[int] = None + status: "AssetStatus" + model_status: "AssetStatus" + epoch: Optional[float] = None step: Optional[int] = None learning_rate: Optional[float] = None training_loss: Optional[float] = None diff --git a/aixplain/modules/metric.py b/aixplain/modules/metric.py index 8d8844f0..4bcefa37 100644 --- a/aixplain/modules/metric.py +++ b/aixplain/modules/metric.py @@ -23,8 +23,6 @@ from typing import Optional, Text, List, Union from aixplain.modules.asset import Asset -from aixplain.utils.file_utils import _request_with_retry -from aixplain.factories.model_factory import ModelFactory class Metric(Asset): @@ -92,6 +90,7 @@ def run(self, hypothesis: 
Optional[Union[str, List[str]]]=None, source: Optional source (Optional[Union[str, List[str]]], optional): Can give a single source or a list of sources for metric calculation. Defaults to None. reference (Optional[Union[str, List[str]]], optional): Can give a single reference or a list of references for metric calculation. Defaults to None. """ + from aixplain.factories.model_factory import ModelFactory model = ModelFactory.get(self.id) payload = { "function": self.function, diff --git a/aixplain/modules/model.py b/aixplain/modules/model.py index 81684ac0..985444f8 100644 --- a/aixplain/modules/model.py +++ b/aixplain/modules/model.py @@ -20,7 +20,6 @@ Description: Model Class """ - import time import json import logging @@ -29,7 +28,6 @@ from aixplain.factories.file_factory import FileFactory from aixplain.enums import Function, Supplier from aixplain.modules.asset import Asset -from aixplain.modules.finetune.status import FinetuneStatus, FinetuneState from aixplain.utils import config from urllib.parse import urljoin from aixplain.utils.file_utils import _request_with_retry @@ -252,7 +250,7 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param response["error"] = msg return response - def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: Optional[int] = None) -> FinetuneStatus: + def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: Optional[int] = None): """Check the status of the FineTune model. Args: @@ -265,14 +263,45 @@ def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: O Returns: FinetuneStatus: The status of the FineTune model. 
""" + from aixplain.enums.asset_status import AssetStatus + from aixplain.modules.finetune.status import FinetuneStatus headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + resp = None try: - url = urljoin(self.backend_url, f"sdk/models/{self.id}") + url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs") logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - status = resp["status"] - logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status}.") + finetune_status = AssetStatus(resp["finetuneStatus"]) + model_status = AssetStatus(resp["modelStatus"]) + logs = sorted(resp["logs"], key=lambda x: (float(x["epoch"]), int(x["step"]))) + + if after_epoch is None and after_step is None: + logs = logs[-1:] + else: + if after_epoch is not None: + logs = [log for log in logs if float(log["epoch"]) >= after_epoch] + if after_step is not None: + logs = [log for log in logs if log["step"] >= after_step] + + if len(logs) > 0: + log = logs[0] + status = FinetuneStatus( + status=finetune_status, + model_status=model_status, + epoch=float(log["epoch"]), + step=int(log["step"]), + learning_rate=float(log["learningRate"]), + training_loss=float(log["trainLoss"]), + validation_loss=float(log["validationLoss"]), + ) + else: + status = FinetuneStatus( + status=finetune_status, + model_status=model_status, + ) + + logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.") return status except Exception as e: message = "" diff --git a/tests/unit/finetune_test.py b/tests/unit/finetune_test.py index a43b7b46..49f48c70 100644 --- a/tests/unit/finetune_test.py +++ b/tests/unit/finetune_test.py @@ -27,6 +27,7 @@ from aixplain.factories import FinetuneFactory from aixplain.modules import Model, Finetune from aixplain.enums import Function +from urllib.parse import urljoin import 
pytest @@ -95,20 +96,26 @@ def test_start(): assert fine_tuned_model is not None assert fine_tuned_model.id == model_map["id"] - -def test_check_finetuner_status(): +@pytest.mark.parametrize( + "after_epoch,after_step,training_loss,validation_loss", + [ + (None, None, 0.4, 0.0217), + (1, 10, 0.1, 0.1106), + (1, 20, 0.2, 0.0482) + ] +) +def test_check_finetuner_status(after_epoch, after_step, training_loss, validation_loss): model_map = read_data(FINETUNE_STATUS_FILE) asset_id = "test_id" with requests_mock.Mocker() as mock: test_model = Model(asset_id, "") - url = f"{MODEL_URL}/{asset_id}" + url = urljoin(config.BACKEND_URL, f"sdk/finetune/{asset_id}/ml-logs") mock.get(url, headers=FIXED_HEADER, json=model_map) - status = test_model.check_finetune_status() - assert status.status.value == model_map["status"] - assert status.training_loss == 0.007 - assert status.epoch == 2.75 - assert status.step == 1500 - assert status.learning_rate == 8.088235294117648e-06 + status = test_model.check_finetune_status(after_epoch=after_epoch, after_step=after_step) + assert status.status.value == model_map["finetuneStatus"] + assert status.model_status.value == model_map["modelStatus"] + assert status.training_loss == training_loss + assert status.validation_loss == validation_loss @pytest.mark.parametrize("is_finetunable", [True, False]) diff --git a/tests/unit/mock_responses/finetune_status_response.json b/tests/unit/mock_responses/finetune_status_response.json index d87ddd91..a8a080d3 100644 --- a/tests/unit/mock_responses/finetune_status_response.json +++ b/tests/unit/mock_responses/finetune_status_response.json @@ -1,978 +1,41 @@ { - "status": "onboarding", - "trainer_state": { - "best_metric": null, - "best_model_checkpoint": null, - "epoch": 2.7529249827942186, - "eval_steps": 200, - "global_step": 1500, - "is_hyper_param_search": false, - "is_local_process_zero": true, - "is_world_process_zero": true, - "log_history": [ - { - "epoch": 0.02, - "learning_rate": 
9.938725490196079e-05, - "loss": 0.1106, - "step": 10 - }, - { - "epoch": 0.04, - "learning_rate": 9.877450980392157e-05, - "loss": 0.0482, - "step": 20 - }, - { - "epoch": 0.06, - "learning_rate": 9.816176470588235e-05, - "loss": 0.0251, - "step": 30 - }, - { - "epoch": 0.07, - "learning_rate": 9.754901960784314e-05, - "loss": 0.0228, - "step": 40 - }, - { - "epoch": 0.09, - "learning_rate": 9.693627450980392e-05, - "loss": 0.0217, - "step": 50 - }, - { - "epoch": 0.11, - "learning_rate": 9.632352941176472e-05, - "loss": 0.0126, - "step": 60 - }, - { - "epoch": 0.13, - "learning_rate": 9.57107843137255e-05, - "loss": 0.0111, - "step": 70 - }, - { - "epoch": 0.15, - "learning_rate": 9.509803921568627e-05, - "loss": 0.0162, - "step": 80 - }, - { - "epoch": 0.17, - "learning_rate": 9.448529411764707e-05, - "loss": 0.0149, - "step": 90 - }, - { - "epoch": 0.18, - "learning_rate": 9.387254901960785e-05, - "loss": 0.011, - "step": 100 - }, - { - "epoch": 0.2, - "learning_rate": 9.325980392156863e-05, - "loss": 0.0165, - "step": 110 - }, - { - "epoch": 0.22, - "learning_rate": 9.264705882352942e-05, - "loss": 0.0139, - "step": 120 - }, - { - "epoch": 0.24, - "learning_rate": 9.20343137254902e-05, - "loss": 0.0108, - "step": 130 - }, - { - "epoch": 0.26, - "learning_rate": 9.142156862745098e-05, - "loss": 0.0115, - "step": 140 - }, - { - "epoch": 0.28, - "learning_rate": 9.080882352941177e-05, - "loss": 0.0122, - "step": 150 - }, - { - "epoch": 0.29, - "learning_rate": 9.019607843137255e-05, - "loss": 0.0106, - "step": 160 - }, - { - "epoch": 0.31, - "learning_rate": 8.958333333333335e-05, - "loss": 0.0119, - "step": 170 - }, - { - "epoch": 0.33, - "learning_rate": 8.897058823529412e-05, - "loss": 0.0107, - "step": 180 - }, - { - "epoch": 0.35, - "learning_rate": 8.83578431372549e-05, - "loss": 0.0108, - "step": 190 - }, - { - "epoch": 0.37, - "learning_rate": 8.774509803921568e-05, - "loss": 0.0101, - "step": 200 - }, - { - "epoch": 0.37, - "eval_loss": 
0.009765202179551125, - "eval_runtime": 656.7127, - "eval_samples_per_second": 1.325, - "eval_steps_per_second": 0.662, - "step": 200 - }, - { - "epoch": 0.39, - "learning_rate": 8.713235294117648e-05, - "loss": 0.0119, - "step": 210 - }, - { - "epoch": 0.4, - "learning_rate": 8.651960784313726e-05, - "loss": 0.0099, - "step": 220 - }, - { - "epoch": 0.42, - "learning_rate": 8.590686274509803e-05, - "loss": 0.0105, - "step": 230 - }, - { - "epoch": 0.44, - "learning_rate": 8.529411764705883e-05, - "loss": 0.011, - "step": 240 - }, - { - "epoch": 0.46, - "learning_rate": 8.468137254901961e-05, - "loss": 0.0104, - "step": 250 - }, - { - "epoch": 0.48, - "learning_rate": 8.40686274509804e-05, - "loss": 0.0094, - "step": 260 - }, - { - "epoch": 0.5, - "learning_rate": 8.345588235294118e-05, - "loss": 0.0108, - "step": 270 - }, - { - "epoch": 0.51, - "learning_rate": 8.284313725490198e-05, - "loss": 0.0081, - "step": 280 - }, - { - "epoch": 0.53, - "learning_rate": 8.223039215686275e-05, - "loss": 0.0103, - "step": 290 - }, - { - "epoch": 0.55, - "learning_rate": 8.161764705882353e-05, - "loss": 0.01, - "step": 300 - }, - { - "epoch": 0.57, - "learning_rate": 8.100490196078431e-05, - "loss": 0.0111, - "step": 310 - }, - { - "epoch": 0.59, - "learning_rate": 8.039215686274511e-05, - "loss": 0.0097, - "step": 320 - }, - { - "epoch": 0.61, - "learning_rate": 7.977941176470589e-05, - "loss": 0.0093, - "step": 330 - }, - { - "epoch": 0.62, - "learning_rate": 7.916666666666666e-05, - "loss": 0.0093, - "step": 340 - }, - { - "epoch": 0.64, - "learning_rate": 7.855392156862746e-05, - "loss": 0.0104, - "step": 350 - }, - { - "epoch": 0.66, - "learning_rate": 7.794117647058824e-05, - "loss": 0.0094, - "step": 360 - }, - { - "epoch": 0.68, - "learning_rate": 7.732843137254903e-05, - "loss": 0.0099, - "step": 370 - }, - { - "epoch": 0.7, - "learning_rate": 7.671568627450981e-05, - "loss": 0.0092, - "step": 380 - }, - { - "epoch": 0.72, - "learning_rate": 7.610294117647059e-05, - 
"loss": 0.0082, - "step": 390 - }, - { - "epoch": 0.73, - "learning_rate": 7.549019607843137e-05, - "loss": 0.0098, - "step": 400 - }, - { - "epoch": 0.73, - "eval_loss": 0.008508323691785336, - "eval_runtime": 657.1772, - "eval_samples_per_second": 1.324, - "eval_steps_per_second": 0.662, - "step": 400 - }, - { - "epoch": 0.75, - "learning_rate": 7.487745098039216e-05, - "loss": 0.01, - "step": 410 - }, - { - "epoch": 0.77, - "learning_rate": 7.426470588235294e-05, - "loss": 0.011, - "step": 420 - }, - { - "epoch": 0.79, - "learning_rate": 7.365196078431374e-05, - "loss": 0.0081, - "step": 430 - }, - { - "epoch": 0.81, - "learning_rate": 7.303921568627451e-05, - "loss": 0.01, - "step": 440 - }, - { - "epoch": 0.83, - "learning_rate": 7.242647058823529e-05, - "loss": 0.0088, - "step": 450 - }, - { - "epoch": 0.84, - "learning_rate": 7.181372549019609e-05, - "loss": 0.0101, - "step": 460 - }, - { - "epoch": 0.86, - "learning_rate": 7.120098039215687e-05, - "loss": 0.0082, - "step": 470 - }, - { - "epoch": 0.88, - "learning_rate": 7.058823529411765e-05, - "loss": 0.0091, - "step": 480 - }, - { - "epoch": 0.9, - "learning_rate": 6.997549019607842e-05, - "loss": 0.0085, - "step": 490 - }, - { - "epoch": 0.92, - "learning_rate": 6.936274509803922e-05, - "loss": 0.0094, - "step": 500 - }, - { - "epoch": 0.94, - "learning_rate": 6.875e-05, - "loss": 0.0093, - "step": 510 - }, - { - "epoch": 0.95, - "learning_rate": 6.813725490196079e-05, - "loss": 0.0087, - "step": 520 - }, - { - "epoch": 0.97, - "learning_rate": 6.752450980392157e-05, - "loss": 0.0096, - "step": 530 - }, - { - "epoch": 0.99, - "learning_rate": 6.691176470588235e-05, - "loss": 0.0089, - "step": 540 - }, - { - "epoch": 1.01, - "learning_rate": 6.629901960784314e-05, - "loss": 0.0084, - "step": 550 - }, - { - "epoch": 1.03, - "learning_rate": 6.568627450980392e-05, - "loss": 0.0073, - "step": 560 - }, - { - "epoch": 1.05, - "learning_rate": 6.507352941176472e-05, - "loss": 0.0066, - "step": 570 - }, - { - 
"epoch": 1.06, - "learning_rate": 6.44607843137255e-05, - "loss": 0.0083, - "step": 580 - }, - { - "epoch": 1.08, - "learning_rate": 6.384803921568627e-05, - "loss": 0.008, - "step": 590 - }, - { - "epoch": 1.1, - "learning_rate": 6.323529411764705e-05, - "loss": 0.0083, - "step": 600 - }, - { - "epoch": 1.1, - "eval_loss": 0.007716518826782703, - "eval_runtime": 657.2554, - "eval_samples_per_second": 1.324, - "eval_steps_per_second": 0.662, - "step": 600 - }, - { - "epoch": 1.12, - "learning_rate": 6.262254901960785e-05, - "loss": 0.0085, - "step": 610 - }, - { - "epoch": 1.14, - "learning_rate": 6.200980392156863e-05, - "loss": 0.0069, - "step": 620 - }, - { - "epoch": 1.16, - "learning_rate": 6.139705882352942e-05, - "loss": 0.0082, - "step": 630 - }, - { - "epoch": 1.17, - "learning_rate": 6.078431372549019e-05, - "loss": 0.0075, - "step": 640 - }, - { - "epoch": 1.19, - "learning_rate": 6.017156862745098e-05, - "loss": 0.0068, - "step": 650 - }, - { - "epoch": 1.21, - "learning_rate": 5.9558823529411766e-05, - "loss": 0.007, - "step": 660 - }, - { - "epoch": 1.23, - "learning_rate": 5.8946078431372556e-05, - "loss": 0.0086, - "step": 670 - }, - { - "epoch": 1.25, - "learning_rate": 5.833333333333334e-05, - "loss": 0.0075, - "step": 680 - }, - { - "epoch": 1.27, - "learning_rate": 5.7720588235294116e-05, - "loss": 0.0071, - "step": 690 - }, - { - "epoch": 1.28, - "learning_rate": 5.71078431372549e-05, - "loss": 0.0083, - "step": 700 - }, - { - "epoch": 1.3, - "learning_rate": 5.649509803921569e-05, - "loss": 0.0069, - "step": 710 - }, - { - "epoch": 1.32, - "learning_rate": 5.588235294117647e-05, - "loss": 0.0091, - "step": 720 - }, - { - "epoch": 1.34, - "learning_rate": 5.526960784313726e-05, - "loss": 0.0067, - "step": 730 - }, - { - "epoch": 1.36, - "learning_rate": 5.465686274509804e-05, - "loss": 0.0067, - "step": 740 - }, - { - "epoch": 1.38, - "learning_rate": 5.404411764705882e-05, - "loss": 0.008, - "step": 750 - }, - { - "epoch": 1.39, - 
"learning_rate": 5.343137254901961e-05, - "loss": 0.0073, - "step": 760 - }, - { - "epoch": 1.41, - "learning_rate": 5.2818627450980395e-05, - "loss": 0.0076, - "step": 770 - }, - { - "epoch": 1.43, - "learning_rate": 5.2205882352941185e-05, - "loss": 0.0066, - "step": 780 - }, - { - "epoch": 1.45, - "learning_rate": 5.159313725490197e-05, - "loss": 0.0067, - "step": 790 - }, - { - "epoch": 1.47, - "learning_rate": 5.0980392156862745e-05, - "loss": 0.007, - "step": 800 - }, - { - "epoch": 1.47, - "eval_loss": 0.007389526814222336, - "eval_runtime": 656.7058, - "eval_samples_per_second": 1.325, - "eval_steps_per_second": 0.662, - "step": 800 - }, - { - "epoch": 1.49, - "learning_rate": 5.036764705882353e-05, - "loss": 0.0058, - "step": 810 - }, - { - "epoch": 1.5, - "learning_rate": 4.975490196078432e-05, - "loss": 0.0074, - "step": 820 - }, - { - "epoch": 1.52, - "learning_rate": 4.9142156862745095e-05, - "loss": 0.0072, - "step": 830 - }, - { - "epoch": 1.54, - "learning_rate": 4.8529411764705885e-05, - "loss": 0.0058, - "step": 840 - }, - { - "epoch": 1.56, - "learning_rate": 4.791666666666667e-05, - "loss": 0.0076, - "step": 850 - }, - { - "epoch": 1.58, - "learning_rate": 4.730392156862745e-05, - "loss": 0.0088, - "step": 860 - }, - { - "epoch": 1.6, - "learning_rate": 4.669117647058824e-05, - "loss": 0.0091, - "step": 870 - }, - { - "epoch": 1.62, - "learning_rate": 4.607843137254902e-05, - "loss": 0.0088, - "step": 880 - }, - { - "epoch": 1.63, - "learning_rate": 4.546568627450981e-05, - "loss": 0.008, - "step": 890 - }, - { - "epoch": 1.65, - "learning_rate": 4.485294117647059e-05, - "loss": 0.0094, - "step": 900 - }, - { - "epoch": 1.67, - "learning_rate": 4.4240196078431374e-05, - "loss": 0.009, - "step": 910 - }, - { - "epoch": 1.69, - "learning_rate": 4.362745098039216e-05, - "loss": 0.0063, - "step": 920 - }, - { - "epoch": 1.71, - "learning_rate": 4.301470588235295e-05, - "loss": 0.0078, - "step": 930 - }, - { - "epoch": 1.73, - "learning_rate": 
4.2401960784313724e-05, - "loss": 0.0062, - "step": 940 - }, - { - "epoch": 1.74, - "learning_rate": 4.1789215686274514e-05, - "loss": 0.0066, - "step": 950 - }, - { - "epoch": 1.76, - "learning_rate": 4.11764705882353e-05, - "loss": 0.0068, - "step": 960 - }, - { - "epoch": 1.78, - "learning_rate": 4.056372549019608e-05, - "loss": 0.0063, - "step": 970 - }, - { - "epoch": 1.8, - "learning_rate": 3.9950980392156864e-05, - "loss": 0.0064, - "step": 980 - }, - { - "epoch": 1.82, - "learning_rate": 3.933823529411765e-05, - "loss": 0.0072, - "step": 990 - }, - { - "epoch": 1.84, - "learning_rate": 3.872549019607844e-05, - "loss": 0.0066, - "step": 1000 - }, - { - "epoch": 1.84, - "eval_loss": 0.007207777351140976, - "eval_runtime": 657.2316, - "eval_samples_per_second": 1.324, - "eval_steps_per_second": 0.662, - "step": 1000 - }, - { - "epoch": 1.85, - "learning_rate": 3.8112745098039213e-05, - "loss": 0.0063, - "step": 1010 - }, - { - "epoch": 1.87, - "learning_rate": 3.7500000000000003e-05, - "loss": 0.0079, - "step": 1020 - }, - { - "epoch": 1.89, - "learning_rate": 3.688725490196079e-05, - "loss": 0.0073, - "step": 1030 - }, - { - "epoch": 1.91, - "learning_rate": 3.627450980392157e-05, - "loss": 0.0058, - "step": 1040 - }, - { - "epoch": 1.93, - "learning_rate": 3.566176470588235e-05, - "loss": 0.0068, - "step": 1050 - }, - { - "epoch": 1.95, - "learning_rate": 3.5049019607843136e-05, - "loss": 0.0065, - "step": 1060 - }, - { - "epoch": 1.96, - "learning_rate": 3.443627450980392e-05, - "loss": 0.0059, - "step": 1070 - }, - { - "epoch": 1.98, - "learning_rate": 3.382352941176471e-05, - "loss": 0.0072, - "step": 1080 - }, - { - "epoch": 2.0, - "learning_rate": 3.321078431372549e-05, - "loss": 0.0077, - "step": 1090 - }, - { - "epoch": 2.02, - "learning_rate": 3.2598039215686276e-05, - "loss": 0.0058, - "step": 1100 - }, - { - "epoch": 2.04, - "learning_rate": 3.198529411764706e-05, - "loss": 0.0052, - "step": 1110 - }, - { - "epoch": 2.06, - "learning_rate": 
3.137254901960784e-05, - "loss": 0.0068, - "step": 1120 - }, - { - "epoch": 2.07, - "learning_rate": 3.075980392156863e-05, - "loss": 0.0064, - "step": 1130 - }, - { - "epoch": 2.09, - "learning_rate": 3.0147058823529413e-05, - "loss": 0.0076, - "step": 1140 - }, - { - "epoch": 2.11, - "learning_rate": 2.95343137254902e-05, - "loss": 0.0071, - "step": 1150 - }, - { - "epoch": 2.13, - "learning_rate": 2.8921568627450986e-05, - "loss": 0.0065, - "step": 1160 - }, - { - "epoch": 2.15, - "learning_rate": 2.8308823529411766e-05, - "loss": 0.0076, - "step": 1170 - }, - { - "epoch": 2.17, - "learning_rate": 2.7696078431372552e-05, - "loss": 0.005, - "step": 1180 - }, - { - "epoch": 2.18, - "learning_rate": 2.7083333333333332e-05, - "loss": 0.0064, - "step": 1190 - }, - { - "epoch": 2.2, - "learning_rate": 2.647058823529412e-05, - "loss": 0.0069, - "step": 1200 - }, - { - "epoch": 2.2, - "eval_loss": 0.006944665219634771, - "eval_runtime": 657.2803, - "eval_samples_per_second": 1.324, - "eval_steps_per_second": 0.662, - "step": 1200 - }, - { - "epoch": 2.22, - "learning_rate": 2.5857843137254905e-05, - "loss": 0.0057, - "step": 1210 - }, - { - "epoch": 2.24, - "learning_rate": 2.5245098039215685e-05, - "loss": 0.0062, - "step": 1220 - }, - { - "epoch": 2.26, - "learning_rate": 2.4632352941176472e-05, - "loss": 0.0061, - "step": 1230 - }, - { - "epoch": 2.28, - "learning_rate": 2.401960784313726e-05, - "loss": 0.0048, - "step": 1240 - }, - { - "epoch": 2.29, - "learning_rate": 2.340686274509804e-05, - "loss": 0.0068, - "step": 1250 - }, - { - "epoch": 2.31, - "learning_rate": 2.2794117647058825e-05, - "loss": 0.0059, - "step": 1260 - }, - { - "epoch": 2.33, - "learning_rate": 2.2181372549019608e-05, - "loss": 0.005, - "step": 1270 - }, - { - "epoch": 2.35, - "learning_rate": 2.1568627450980395e-05, - "loss": 0.0058, - "step": 1280 - }, - { - "epoch": 2.37, - "learning_rate": 2.0955882352941178e-05, - "loss": 0.0055, - "step": 1290 - }, - { - "epoch": 2.39, - 
"learning_rate": 2.034313725490196e-05, - "loss": 0.0067, - "step": 1300 - }, - { - "epoch": 2.4, - "learning_rate": 1.9730392156862744e-05, - "loss": 0.0054, - "step": 1310 - }, - { - "epoch": 2.42, - "learning_rate": 1.9117647058823528e-05, - "loss": 0.0067, - "step": 1320 - }, - { - "epoch": 2.44, - "learning_rate": 1.8504901960784314e-05, - "loss": 0.0061, - "step": 1330 - }, - { - "epoch": 2.46, - "learning_rate": 1.7892156862745098e-05, - "loss": 0.0061, - "step": 1340 - }, - { - "epoch": 2.48, - "learning_rate": 1.7279411764705884e-05, - "loss": 0.0058, - "step": 1350 - }, - { - "epoch": 2.5, - "learning_rate": 1.6666666666666667e-05, - "loss": 0.0077, - "step": 1360 - }, - { - "epoch": 2.51, - "learning_rate": 1.6053921568627454e-05, - "loss": 0.0065, - "step": 1370 - }, - { - "epoch": 2.53, - "learning_rate": 1.5441176470588237e-05, - "loss": 0.0055, - "step": 1380 - }, - { - "epoch": 2.55, - "learning_rate": 1.482843137254902e-05, - "loss": 0.0058, - "step": 1390 - }, - { - "epoch": 2.57, - "learning_rate": 1.4215686274509804e-05, - "loss": 0.0057, - "step": 1400 - }, - { - "epoch": 2.57, - "eval_loss": 0.006838853470981121, - "eval_runtime": 656.7781, - "eval_samples_per_second": 1.325, - "eval_steps_per_second": 0.662, - "step": 1400 - }, - { - "epoch": 2.59, - "learning_rate": 1.3602941176470587e-05, - "loss": 0.0058, - "step": 1410 - }, - { - "epoch": 2.61, - "learning_rate": 1.2990196078431374e-05, - "loss": 0.0072, - "step": 1420 - }, - { - "epoch": 2.62, - "learning_rate": 1.2377450980392159e-05, - "loss": 0.0052, - "step": 1430 - }, - { - "epoch": 2.64, - "learning_rate": 1.1764705882352942e-05, - "loss": 0.0051, - "step": 1440 - }, - { - "epoch": 2.66, - "learning_rate": 1.1151960784313727e-05, - "loss": 0.0063, - "step": 1450 - }, - { - "epoch": 2.68, - "learning_rate": 1.053921568627451e-05, - "loss": 0.0062, - "step": 1460 - }, - { - "epoch": 2.7, - "learning_rate": 9.926470588235293e-06, - "loss": 0.005, - "step": 1470 - }, - { - "epoch": 
2.72, - "learning_rate": 9.31372549019608e-06, - "loss": 0.0059, - "step": 1480 - }, - { - "epoch": 2.73, - "learning_rate": 8.700980392156863e-06, - "loss": 0.0055, - "step": 1490 - }, - { - "epoch": 2.75, - "learning_rate": 8.088235294117648e-06, - "loss": 0.007, - "step": 1500 - } - ], - "logging_steps": 10, - "max_steps": 1632, - "num_train_epochs": 3, - "save_steps": 500, - "total_flos": 3.492274811692892e+18, - "trial_name": null, - "trial_params": null - } + "finetuneStatus": "onboarding", + "modelStatus": "onboarded", + "logs": [ + { + "epoch": 1, + "learningRate": 9.938725490196079e-05, + "trainLoss": 0.1, + "validationLoss": 0.1106, + "step": 10 + }, + { + "epoch": 2, + "learningRate": 9.877450980392157e-05, + "trainLoss": 0.2, + "validationLoss": 0.0482, + "step": 20 + }, + { + "epoch": 3, + "learningRate": 9.816176470588235e-05, + "trainLoss": 0.3, + "validationLoss": 0.0251, + "step": 30 + }, + { + "epoch": 4, + "learningRate": 9.754901960784314e-05, + "trainLoss": 0.9, + "validationLoss": 0.0228, + "step": 40 + }, + { + "epoch": 5, + "learningRate": 9.693627450980392e-05, + "trainLoss": 0.4, + "validationLoss": 0.0217, + "step": 50 + } + ] } \ No newline at end of file From a8f9607f9a8ad2fbdb2281e457a4ab9c8f43693b Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Thu, 14 Mar 2024 16:53:41 -0300 Subject: [PATCH 3/4] Covering some None cases --- aixplain/modules/model.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/aixplain/modules/model.py b/aixplain/modules/model.py index 985444f8..6e2e89e7 100644 --- a/aixplain/modules/model.py +++ b/aixplain/modules/model.py @@ -274,7 +274,11 @@ def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: O resp = r.json() finetune_status = AssetStatus(resp["finetuneStatus"]) model_status = AssetStatus(resp["modelStatus"]) - logs = sorted(resp["logs"], key=lambda x: (float(x["epoch"]), int(x["step"]))) + try: + logs = sorted(resp["logs"], 
key=lambda x: (float(x["epoch"]), int(x["step"]))) + except Exception: + # if step is not stored + logs = sorted(resp["logs"], key=lambda x: float(x["epoch"])) if after_epoch is None and after_step is None: logs = logs[-1:] @@ -282,18 +286,19 @@ def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: O if after_epoch is not None: logs = [log for log in logs if float(log["epoch"]) >= after_epoch] if after_step is not None: - logs = [log for log in logs if log["step"] >= after_step] + if len(logs) > 0 and "step" in logs[0]: + logs = [log for log in logs if log["step"] >= after_step] if len(logs) > 0: log = logs[0] status = FinetuneStatus( status=finetune_status, model_status=model_status, - epoch=float(log["epoch"]), - step=int(log["step"]), - learning_rate=float(log["learningRate"]), - training_loss=float(log["trainLoss"]), - validation_loss=float(log["validationLoss"]), + epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None, + step=int(log["step"]) if "step" in log and log["step"] is not None else None, + learning_rate=float(log["learningRate"]) if "learningRate" in log and log["learningRate"] is not None else None, + training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None, + validation_loss=float(log["validationLoss"]) if "validationLoss" in log and log["validationLoss"] is not None else None, ) else: status = FinetuneStatus( From 4322050b672832113fa3d1e903efec731b659bb1 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Thu, 21 Mar 2024 14:56:48 -0300 Subject: [PATCH 4/4] Finetune status checker updates and new unit tests --- aixplain/modules/finetune/status.py | 2 - aixplain/modules/model.py | 44 +++++++++-------- tests/unit/finetune_test.py | 18 ++++--- .../finetune_status_response.json | 10 ++-- .../finetune_status_response_2.json | 49 +++++++++++++++++++ 5 files changed, 87 insertions(+), 36 deletions(-) create mode 100644 
tests/unit/mock_responses/finetune_status_response_2.json diff --git a/aixplain/modules/finetune/status.py b/aixplain/modules/finetune/status.py index 89c0cd16..4994ce55 100644 --- a/aixplain/modules/finetune/status.py +++ b/aixplain/modules/finetune/status.py @@ -32,7 +32,5 @@ class FinetuneStatus(object): status: "AssetStatus" model_status: "AssetStatus" epoch: Optional[float] = None - step: Optional[int] = None - learning_rate: Optional[float] = None training_loss: Optional[float] = None validation_loss: Optional[float] = None diff --git a/aixplain/modules/model.py b/aixplain/modules/model.py index 6e2e89e7..fc3a82cd 100644 --- a/aixplain/modules/model.py +++ b/aixplain/modules/model.py @@ -250,12 +250,11 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param response["error"] = msg return response - def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: Optional[int] = None): + def check_finetune_status(self, after_epoch: Optional[int] = None): """Check the status of the FineTune model. Args: after_epoch (Optional[int], optional): status after a given epoch. Defaults to None. - after_step (Optional[int], optional): status after a given step. Defaults to None. Raises: Exception: If the 'TEAM_API_KEY' is not provided. 
@@ -274,31 +273,34 @@ def check_finetune_status(self, after_epoch: Optional[int] = None, after_step: O resp = r.json() finetune_status = AssetStatus(resp["finetuneStatus"]) model_status = AssetStatus(resp["modelStatus"]) - try: - logs = sorted(resp["logs"], key=lambda x: (float(x["epoch"]), int(x["step"]))) - except Exception: - # if step is not stored - logs = sorted(resp["logs"], key=lambda x: float(x["epoch"])) - - if after_epoch is None and after_step is None: - logs = logs[-1:] - else: - if after_epoch is not None: - logs = [log for log in logs if float(log["epoch"]) >= after_epoch] - if after_step is not None: - if len(logs) > 0 and "step" in logs[0]: - logs = [log for log in logs if log["step"] >= after_step] + logs = sorted(resp["logs"], key=lambda x: float(x["epoch"])) + + target_epoch = None + if after_epoch is not None: + logs = [log for log in logs if float(log["epoch"]) > after_epoch] + if len(logs) > 0: + target_epoch = float(logs[0]["epoch"]) + elif len(logs) > 0: + target_epoch = float(logs[-1]["epoch"]) - if len(logs) > 0: - log = logs[0] + if target_epoch is not None: + log = None + for log_ in logs: + if int(log_["epoch"]) == target_epoch: + if log is None: + log = log_ + else: + if log_["trainLoss"] is not None: + log["trainLoss"] = log_["trainLoss"] + if log_["evalLoss"] is not None: + log["evalLoss"] = log_["evalLoss"] + status = FinetuneStatus( status=finetune_status, model_status=model_status, epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None, - step=int(log["step"]) if "step" in log and log["step"] is not None else None, - learning_rate=float(log["learningRate"]) if "learningRate" in log and log["learningRate"] is not None else None, training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None, - validation_loss=float(log["validationLoss"]) if "validationLoss" in log and log["validationLoss"] is not None else None, + validation_loss=float(log["evalLoss"]) if 
"evalLoss" in log and log["evalLoss"] is not None else None, ) else: status = FinetuneStatus( diff --git a/tests/unit/finetune_test.py b/tests/unit/finetune_test.py index 49f48c70..2c6848b1 100644 --- a/tests/unit/finetune_test.py +++ b/tests/unit/finetune_test.py @@ -37,6 +37,7 @@ FINETUNE_URL = f"{config.BACKEND_URL}/sdk/finetune" FINETUNE_FILE = "tests/unit/mock_responses/finetune_response.json" FINETUNE_STATUS_FILE = "tests/unit/mock_responses/finetune_status_response.json" +FINETUNE_STATUS_FILE_2 = "tests/unit/mock_responses/finetune_status_response_2.json" PERCENTAGE_EXCEPTION_FILE = "tests/unit/data/create_finetune_percentage_exception.json" MODEL_FILE = "tests/unit/mock_responses/model_response.json" MODEL_URL = f"{config.BACKEND_URL}/sdk/models" @@ -97,21 +98,22 @@ def test_start(): assert fine_tuned_model.id == model_map["id"] @pytest.mark.parametrize( - "after_epoch,after_step,training_loss,validation_loss", + "input_path,after_epoch,training_loss,validation_loss", [ - (None, None, 0.4, 0.0217), - (1, 10, 0.1, 0.1106), - (1, 20, 0.2, 0.0482) + (FINETUNE_STATUS_FILE, None, 0.4, 0.0217), + (FINETUNE_STATUS_FILE, 1, 0.2, 0.0482), + (FINETUNE_STATUS_FILE_2, None, 2.657801408034, 2.596168756485), + (FINETUNE_STATUS_FILE_2, 0, None, 2.684150457382) ] ) -def test_check_finetuner_status(after_epoch, after_step, training_loss, validation_loss): - model_map = read_data(FINETUNE_STATUS_FILE) +def test_check_finetuner_status(input_path, after_epoch, training_loss, validation_loss): + model_map = read_data(input_path) asset_id = "test_id" with requests_mock.Mocker() as mock: test_model = Model(asset_id, "") url = urljoin(config.BACKEND_URL, f"sdk/finetune/{asset_id}/ml-logs") mock.get(url, headers=FIXED_HEADER, json=model_map) - status = test_model.check_finetune_status(after_epoch=after_epoch, after_step=after_step) + status = test_model.check_finetune_status(after_epoch=after_epoch) assert status.status.value == model_map["finetuneStatus"] assert 
status.model_status.value == model_map["modelStatus"] assert status.training_loss == training_loss @@ -132,4 +134,4 @@ def test_list_finetunable_models(is_finetunable): model_list = result["results"] assert len(model_list) > 0 for model_index in range(len(model_list)): - assert model_list[model_index].id == list_map["items"][model_index]["id"] + assert model_list[model_index].id == list_map["items"][model_index]["id"] \ No newline at end of file diff --git a/tests/unit/mock_responses/finetune_status_response.json b/tests/unit/mock_responses/finetune_status_response.json index a8a080d3..9647b164 100644 --- a/tests/unit/mock_responses/finetune_status_response.json +++ b/tests/unit/mock_responses/finetune_status_response.json @@ -6,35 +6,35 @@ "epoch": 1, "learningRate": 9.938725490196079e-05, "trainLoss": 0.1, - "validationLoss": 0.1106, + "evalLoss": 0.1106, "step": 10 }, { "epoch": 2, "learningRate": 9.877450980392157e-05, "trainLoss": 0.2, - "validationLoss": 0.0482, + "evalLoss": 0.0482, "step": 20 }, { "epoch": 3, "learningRate": 9.816176470588235e-05, "trainLoss": 0.3, - "validationLoss": 0.0251, + "evalLoss": 0.0251, "step": 30 }, { "epoch": 4, "learningRate": 9.754901960784314e-05, "trainLoss": 0.9, - "validationLoss": 0.0228, + "evalLoss": 0.0228, "step": 40 }, { "epoch": 5, "learningRate": 9.693627450980392e-05, "trainLoss": 0.4, - "validationLoss": 0.0217, + "evalLoss": 0.0217, "step": 50 } ] diff --git a/tests/unit/mock_responses/finetune_status_response_2.json b/tests/unit/mock_responses/finetune_status_response_2.json new file mode 100644 index 00000000..ea5814a0 --- /dev/null +++ b/tests/unit/mock_responses/finetune_status_response_2.json @@ -0,0 +1,49 @@ +{ + "id": "65fb26268fe9153a6c9c29c4", + "finetuneStatus": "in_progress", + "modelStatus": "training", + "logs": [ + { + "epoch": 1, + "learningRate": null, + "trainLoss": null, + "validationLoss": null, + "step": null, + "evalLoss": 2.684150457382, + "totalFlos": null, + "evalRuntime": 12.4129, + 
"trainRuntime": null, + "evalStepsPerSecond": 0.322, + "trainStepsPerSecond": null, + "evalSamplesPerSecond": 16.112 + }, + { + "epoch": 2, + "learningRate": null, + "trainLoss": null, + "validationLoss": null, + "step": null, + "evalLoss": 2.596168756485, + "totalFlos": null, + "evalRuntime": 11.8249, + "trainRuntime": null, + "evalStepsPerSecond": 0.338, + "trainStepsPerSecond": null, + "evalSamplesPerSecond": 16.913 + }, + { + "epoch": 2, + "learningRate": null, + "trainLoss": 2.657801408034, + "validationLoss": null, + "step": null, + "evalLoss": null, + "totalFlos": 11893948284928, + "evalRuntime": null, + "trainRuntime": 221.7946, + "evalStepsPerSecond": null, + "trainStepsPerSecond": 0.117, + "evalSamplesPerSecond": null + } + ] +} \ No newline at end of file