43 changes: 43 additions & 0 deletions aixplain/enums/asset_status.py
@@ -0,0 +1,43 @@
__author__ = "thiagocastroferreira"

"""
Copyright 2024 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
Date: February 21st 2024
Description:
    Asset Status Enum
"""

from enum import Enum
from typing import Text

class AssetStatus(Text, Enum):
HIDDEN = 'hidden'
SCHEDULED = 'scheduled'
ONBOARDING = 'onboarding'
ONBOARDED = 'onboarded'
PENDING = 'pending'
FAILED = 'failed'
TRAINING = 'training'
REJECTED = 'rejected'
ENABLING = 'enabling'
DELETING = 'deleting'
DISABLED = 'disabled'
DELETED = 'deleted'
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
CANCELING = 'canceling'
CANCELED = 'canceled'
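
Because AssetStatus mixes in Text (an alias of str), members can be built directly from the raw status strings the backend returns and compare equal to those strings. A minimal sketch of that behavior (purely illustrative, not part of the diff):

from aixplain.enums.asset_status import AssetStatus

status = AssetStatus("onboarded")        # look up a member from a raw backend value
assert status is AssetStatus.ONBOARDED   # value lookup returns the existing member
assert status == "onboarded"             # str mixin: equal to the plain string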
1 change: 1 addition & 0 deletions aixplain/modules/__init__.py
@@ -29,5 +29,6 @@
from .model import Model
from .pipeline import Pipeline
from .finetune import Finetune, FinetuneCost
from .finetune.status import FinetuneStatus
from .benchmark import Benchmark
from .benchmark_job import BenchmarkJob
36 changes: 36 additions & 0 deletions aixplain/modules/finetune/status.py
@@ -0,0 +1,36 @@
__author__ = "thiagocastroferreira"

"""
Copyright 2024 The aiXplain SDK authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli
Date: February 21st 2024
Description:
    FinetuneStatus Class
"""

from aixplain.enums.asset_status import AssetStatus
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from typing import Optional, Text

@dataclass_json
@dataclass
class FinetuneStatus(object):
status: "AssetStatus"
model_status: "AssetStatus"

Contributor comment:

Here it should be:

status: AssetStatus
model_status: AssetStatus

epoch: Optional[float] = None
training_loss: Optional[float] = None
validation_loss: Optional[float] = None
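
The dataclass_json decorator also gives FinetuneStatus the usual to_dict/to_json/from_dict helpers for serialization. A minimal construction sketch (values are illustrative, taken loosely from the mock responses below):

from aixplain.enums.asset_status import AssetStatus
from aixplain.modules.finetune.status import FinetuneStatus

status = FinetuneStatus(
    status=AssetStatus.IN_PROGRESS,
    model_status=AssetStatus.TRAINING,
    epoch=2.0,
    training_loss=2.6578,
    validation_loss=2.5962,
)
print(status.status.value, status.training_loss)  # in_progress 2.6578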
2 changes: 2 additions & 0 deletions aixplain/modules/metric.py
@@ -23,6 +23,7 @@

from typing import Optional, Text, List, Union
from aixplain.modules.asset import Asset

from aixplain.utils.file_utils import _request_with_retry
# from aixplain.factories.model_factory import ModelFactory

@@ -92,6 +93,7 @@ def run(self, hypothesis: Optional[Union[str, List[str]]]=None, source: Optional
source (Optional[Union[str, List[str]]], optional): Can give a single source or a list of sources for metric calculation. Defaults to None.
reference (Optional[Union[str, List[str]]], optional): Can give a single reference or a list of references for metric calculation. Defaults to None.
"""
from aixplain.factories.model_factory import ModelFactory
model = ModelFactory.get(self.id)
payload = {
"function": self.function,
53 changes: 47 additions & 6 deletions aixplain/modules/model.py
@@ -20,7 +20,6 @@
Description:
Model Class
"""

import time
import json
import logging
@@ -251,23 +250,65 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
response["error"] = msg
return response

def check_finetune_status(self):
def check_finetune_status(self, after_epoch: Optional[int] = None):

Contributor comment:

Please add FinetuneStatus in the return annotation:

def check_finetune_status(self, after_epoch: Optional[int] = None) -> FinetuneStatus:

"""Check the status of the FineTune model.

Args:
after_epoch (Optional[int], optional): status after a given epoch. Defaults to None.

Raises:
Exception: If the 'TEAM_API_KEY' is not provided.

Returns:
str: The status of the FineTune model.
FinetuneStatus: The status of the FineTune model.
"""
from aixplain.enums.asset_status import AssetStatus
from aixplain.modules.finetune.status import FinetuneStatus
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
resp = None
try:
url = urljoin(self.backend_url, f"sdk/models/{self.id}")
url = urljoin(self.backend_url, f"sdk/finetune/{self.id}/ml-logs")
logging.info(f"Start service for GET Check FineTune status Model - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()
status = resp["status"]
logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status}.")
finetune_status = AssetStatus(resp["finetuneStatus"])
model_status = AssetStatus(resp["modelStatus"])
logs = sorted(resp["logs"], key=lambda x: float(x["epoch"]))

target_epoch = None
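# Pick the first epoch strictly after `after_epoch`, or the latest logged epoch when no filter is given.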
if after_epoch is not None:
logs = [log for log in logs if float(log["epoch"]) > after_epoch]
if len(logs) > 0:
target_epoch = float(logs[0]["epoch"])
elif len(logs) > 0:
target_epoch = float(logs[-1]["epoch"])

if target_epoch is not None:
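# A single epoch can produce separate records for train and eval loss
# (see finetune_status_response_2.json below), so merge them into one entry.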
log = None
for log_ in logs:
if int(log_["epoch"]) == target_epoch:
if log is None:
log = log_
else:
if log_["trainLoss"] is not None:
log["trainLoss"] = log_["trainLoss"]
if log_["evalLoss"] is not None:
log["evalLoss"] = log_["evalLoss"]

status = FinetuneStatus(
status=finetune_status,
model_status=model_status,
epoch=float(log["epoch"]) if "epoch" in log and log["epoch"] is not None else None,
training_loss=float(log["trainLoss"]) if "trainLoss" in log and log["trainLoss"] is not None else None,
validation_loss=float(log["evalLoss"]) if "evalLoss" in log and log["evalLoss"] is not None else None,
)
else:
status = FinetuneStatus(
status=finetune_status,
model_status=model_status,
)

logging.info(f"Response for GET Check FineTune status Model - Id {self.id} / Status {status.status.value}.")
return status
except Exception as e:
message = ""
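
Taken together, a caller can poll the new status object while a FineTune job runs. A minimal sketch, assuming a valid TEAM_API_KEY is configured and "<finetuned-model-id>" is a placeholder for a real finetuned model id (the polling pattern itself is illustrative, not part of this diff):

from aixplain.factories.model_factory import ModelFactory

model = ModelFactory.get("<finetuned-model-id>")
status = model.check_finetune_status()
if status is not None:
    print(status.status.value, status.model_status.value)

# Ask only for progress logged after epoch 2 (i.e. the first record with epoch > 2).
latest = model.check_finetune_status(after_epoch=2)
if latest is not None and latest.epoch is not None:
    print(f"epoch {latest.epoch}: train={latest.training_loss}, eval={latest.validation_loss}")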
28 changes: 21 additions & 7 deletions tests/unit/finetune_test.py
@@ -29,6 +29,7 @@
from aixplain.modules import Model, Finetune
from aixplain.modules.finetune import Hyperparameters
from aixplain.enums import Function
from urllib.parse import urljoin

import pytest

@@ -37,6 +38,8 @@
COST_ESTIMATION_FILE = "tests/unit/mock_responses/cost_estimation_response.json"
FINETUNE_URL = f"{config.BACKEND_URL}/sdk/finetune"
FINETUNE_FILE = "tests/unit/mock_responses/finetune_response.json"
FINETUNE_STATUS_FILE = "tests/unit/mock_responses/finetune_status_response.json"
FINETUNE_STATUS_FILE_2 = "tests/unit/mock_responses/finetune_status_response_2.json"
PERCENTAGE_EXCEPTION_FILE = "tests/unit/data/create_finetune_percentage_exception.json"
MODEL_FILE = "tests/unit/mock_responses/model_response.json"
MODEL_URL = f"{config.BACKEND_URL}/sdk/models"
@@ -106,16 +109,27 @@ def test_start():
assert fine_tuned_model is not None
assert fine_tuned_model.id == model_map["id"]


def test_check_finetuner_status():
model_map = read_data(MODEL_FILE)
@pytest.mark.parametrize(
"input_path,after_epoch,training_loss,validation_loss",
[
(FINETUNE_STATUS_FILE, None, 0.4, 0.0217),
(FINETUNE_STATUS_FILE, 1, 0.2, 0.0482),
(FINETUNE_STATUS_FILE_2, None, 2.657801408034, 2.596168756485),
(FINETUNE_STATUS_FILE_2, 0, None, 2.684150457382)
]
)
def test_check_finetuner_status(input_path, after_epoch, training_loss, validation_loss):
model_map = read_data(input_path)
asset_id = "test_id"
with requests_mock.Mocker() as mock:
test_model = Model(asset_id, "")
url = f"{MODEL_URL}/{asset_id}"
url = urljoin(config.BACKEND_URL, f"sdk/finetune/{asset_id}/ml-logs")
mock.get(url, headers=FIXED_HEADER, json=model_map)
status = test_model.check_finetune_status()
assert status == model_map["status"]
status = test_model.check_finetune_status(after_epoch=after_epoch)
assert status.status.value == model_map["finetuneStatus"]
assert status.model_status.value == model_map["modelStatus"]
assert status.training_loss == training_loss
assert status.validation_loss == validation_loss


@pytest.mark.parametrize("is_finetunable", [True, False])
@@ -132,4 +146,4 @@ def test_list_finetunable_models(is_finetunable):
model_list = result["results"]
assert len(model_list) > 0
for model_index in range(len(model_list)):
assert model_list[model_index].id == list_map["items"][model_index]["id"]
assert model_list[model_index].id == list_map["items"][model_index]["id"]
41 changes: 41 additions & 0 deletions tests/unit/mock_responses/finetune_status_response.json
@@ -0,0 +1,41 @@
{
"finetuneStatus": "onboarding",
"modelStatus": "onboarded",
"logs": [
{
"epoch": 1,
"learningRate": 9.938725490196079e-05,
"trainLoss": 0.1,
"evalLoss": 0.1106,
"step": 10
},
{
"epoch": 2,
"learningRate": 9.877450980392157e-05,
"trainLoss": 0.2,
"evalLoss": 0.0482,
"step": 20
},
{
"epoch": 3,
"learningRate": 9.816176470588235e-05,
"trainLoss": 0.3,
"evalLoss": 0.0251,
"step": 30
},
{
"epoch": 4,
"learningRate": 9.754901960784314e-05,
"trainLoss": 0.9,
"evalLoss": 0.0228,
"step": 40
},
{
"epoch": 5,
"learningRate": 9.693627450980392e-05,
"trainLoss": 0.4,
"evalLoss": 0.0217,
"step": 50
}
]
}
49 changes: 49 additions & 0 deletions tests/unit/mock_responses/finetune_status_response_2.json
@@ -0,0 +1,49 @@
{
"id": "65fb26268fe9153a6c9c29c4",
"finetuneStatus": "in_progress",
"modelStatus": "training",
"logs": [
{
"epoch": 1,
"learningRate": null,
"trainLoss": null,
"validationLoss": null,
"step": null,
"evalLoss": 2.684150457382,
"totalFlos": null,
"evalRuntime": 12.4129,
"trainRuntime": null,
"evalStepsPerSecond": 0.322,
"trainStepsPerSecond": null,
"evalSamplesPerSecond": 16.112
},
{
"epoch": 2,
"learningRate": null,
"trainLoss": null,
"validationLoss": null,
"step": null,
"evalLoss": 2.596168756485,
"totalFlos": null,
"evalRuntime": 11.8249,
"trainRuntime": null,
"evalStepsPerSecond": 0.338,
"trainStepsPerSecond": null,
"evalSamplesPerSecond": 16.913
},
{
"epoch": 2,
"learningRate": null,
"trainLoss": 2.657801408034,
"validationLoss": null,
"step": null,
"evalLoss": null,
"totalFlos": 11893948284928,
"evalRuntime": null,
"trainRuntime": 221.7946,
"evalStepsPerSecond": null,
"trainStepsPerSecond": 0.117,
"evalSamplesPerSecond": null
}
]
}