@@ -24,10 +24,13 @@
 import logging
 from typing import Dict, List, Optional, Text
 import json
+from aixplain.factories.finetune_factory.prompt_validator import validate_prompt
+from aixplain.modules.finetune import Finetune
+from aixplain.modules.finetune.cost import FinetuneCost
+from aixplain.modules.finetune.hyperparameters import Hyperparameters
+from aixplain.modules.finetune.peft import Peft
 from aixplain.modules.dataset import Dataset
 from aixplain.modules.model import Model
-from aixplain.modules.finetune import Finetune
-from aixplain.modules.finetune_cost import FinetuneCost
 from aixplain.utils import config
 from aixplain.utils.file_utils import _request_with_retry
 from urllib.parse import urljoin
@@ -59,17 +62,27 @@ def _create_cost_from_response(cls, response: Dict) -> FinetuneCost:
 
     @classmethod
     def create(
-        cls, name: Text, dataset_list: List[Dataset], model: Model, train_percentage: float = 100, dev_percentage: float = 0
+        cls,
+        name: Text,
+        dataset_list: List[Dataset],
+        model: Model,
+        prompt: Optional[Text] = None,
+        hyperparameters: Optional[Hyperparameters] = None,
+        peft: Optional[Peft] = None,
+        train_percentage: Optional[float] = 100,
+        dev_percentage: Optional[float] = 0,
     ) -> Finetune:
         """Create a Finetune object with the provided information.
 
         Args:
            name (Text): Name of the Finetune.
            dataset_list (List[Dataset]): List of Datasets to be used for fine-tuning.
            model (Model): Model to be fine-tuned.
+           prompt (Text, optional): Fine-tuning prompt. Defaults to None.
+           hyperparameters (Hyperparameters, optional): Hyperparameters for fine-tuning. Defaults to None.
+           peft (Peft, optional): PEFT (Parameter-Efficient Fine-Tuning) configuration. Defaults to None.
            train_percentage (float, optional): Percentage of training samples. Defaults to 100.
            dev_percentage (float, optional): Percentage of development samples. Defaults to 0.
 
        Returns:
            Finetune: The Finetune object created with the provided information or None if there was an error.
        """
@@ -78,24 +91,42 @@ def create(
         assert (
             train_percentage + dev_percentage <= 100
         ), f"Create FineTune: Train percentage + dev percentage ({train_percentage + dev_percentage}) must be less than or equal to one"
+        if prompt is not None:
+            prompt = validate_prompt(prompt, dataset_list)
         try:
             url = urljoin(cls.backend_url, f"sdk/finetune/cost-estimation")
             headers = {"Authorization": f"Token {cls.api_key}", "Content-Type": "application/json"}
-            payload = json.dumps(
-                {
-                    "datasets": [
-                        {"datasetId": dataset.id, "trainPercentage": train_percentage, "devPercentage": dev_percentage}
-                        for dataset in dataset_list
-                    ],
-                    "sourceModelId": model.id,
-                }
-            )
+            payload = {
+                "datasets": [
+                    {"datasetId": dataset.id, "trainPercentage": train_percentage, "devPercentage": dev_percentage}
+                    for dataset in dataset_list
+                ],
+                "sourceModelId": model.id,
+            }
+            parameters = {}
+            if prompt is not None:
+                parameters["prompt"] = prompt
+            if hyperparameters is not None:
+                parameters["hyperparameters"] = hyperparameters.to_dict()
+            if peft is not None:
+                parameters["peft"] = peft.to_dict()
+            payload["parameters"] = parameters
             logging.info(f"Start service for POST Create FineTune - {url} - {headers} - {json.dumps(payload)}")
-            r = _request_with_retry("post", url, headers=headers, data=payload)
+            r = _request_with_retry("post", url, headers=headers, json=payload)
             resp = r.json()
             logging.info(f"Response for POST Create FineTune - Status {resp}")
             cost = cls._create_cost_from_response(resp)
-            return Finetune(name, dataset_list, model, cost, train_percentage=train_percentage, dev_percentage=dev_percentage)
+            return Finetune(
+                name,
+                dataset_list,
+                model,
+                cost,
+                train_percentage=train_percentage,
+                dev_percentage=dev_percentage,
+                prompt=prompt,
+                hyperparameters=hyperparameters,
+                peft=peft,
+            )
         except Exception:
             error_message = f"Create FineTune: Error with payload {json.dumps(payload)}"
             logging.exception(error_message)
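As a quick orientation for reviewers, here is a rough usage sketch of the extended `create` signature. It is a minimal sketch only: it assumes a configured TEAM API key, the asset IDs and the `<<...>>` reference are illustrative, and how the `Dataset` is fetched is left open rather than taken from this diff.

```python
from aixplain.factories.finetune_factory import FinetuneFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.modules.finetune.hyperparameters import Hyperparameters
from aixplain.modules.finetune.peft import Peft

# Hypothetical assets; replace with real IDs from your account.
model = ModelFactory.get("MODEL_ID")
dataset = ...  # a Dataset object, e.g. retrieved via the SDK's dataset factory

finetune = FinetuneFactory.create(
    name="my-finetune",
    dataset_list=[dataset],
    model=model,
    prompt="Summarize: <<DATA_ID_OR_NAME>>",    # optional; rewritten by validate_prompt
    hyperparameters=Hyperparameters(epochs=2),  # optional; defaults in hyperparameters.py
    peft=Peft(peft_lora_r=16),                  # optional LoRA settings
    train_percentage=90,
    dev_percentage=10,
)
print(finetune.cost)  # FinetuneCost parsed from the cost-estimation response
```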
41 changes: 41 additions & 0 deletions aixplain/factories/finetune_factory/prompt_validator.py
@@ -0,0 +1,41 @@
+from typing import List, Text
+from aixplain.modules.dataset import Dataset
+import re
+
+
+def _get_data_list(dataset: Dataset):
+    flatten_target_values = [item for sublist in list(dataset.target_data.values()) for item in sublist]
+    data_list = list(dataset.source_data.values()) + flatten_target_values
+    return data_list
+
+
+def validate_prompt(prompt: Text, dataset_list: List[Dataset]) -> Text:
+    result_prompt = prompt
+    referenced_data = set(re.findall("<<(.+?)>>", prompt))
+    for dataset in dataset_list:
+        data_list = _get_data_list(dataset)
+        for data in data_list:
+            if data.id in referenced_data:
+                result_prompt = result_prompt.replace(f"<<{data.id}>>", f"<<{data.name}>>")
+                referenced_data.remove(data.id)
+                referenced_data.add(data.name)
+
+    # check if dataset list has same data name and it is referenced
+    name_set = set()
+    for dataset in dataset_list:
+        data_list = _get_data_list(dataset)
+        for data in data_list:
+            assert not (
+                data.name in name_set and data.name in referenced_data
+            ), "Datasets must not have more than one referenced data with same name"
+            name_set.add(data.name)
+
+    # check if all referenced data have a respective data in dataset list
+    for dataset in dataset_list:
+        data_list = _get_data_list(dataset)
+        for data in data_list:
+            if data.name in referenced_data:
+                result_prompt = result_prompt.replace(f"<<{data.name}>>", f"{{{data.name}}}")
+                referenced_data.remove(data.name)
+    assert len(referenced_data) == 0, "Referenced data are not present in dataset list"
+    return result_prompt
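A short sketch of what `validate_prompt` does with `<<...>>` references: data IDs are rewritten to data names and then to `{name}` placeholders, while dangling or ambiguous references trip the assertions. The objects below are duck-typed stand-ins for `Dataset` and its data assets, not real SDK calls; the ID and name are made up.

```python
from types import SimpleNamespace

# Duck-typed stand-ins: one source data item with id "data-123" and name "text".
data = SimpleNamespace(id="data-123", name="text")
dataset = SimpleNamespace(source_data={"text": data}, target_data={})

print(validate_prompt("Summarize: <<data-123>>", [dataset]))  # -> "Summarize: {text}"
print(validate_prompt("Summarize: <<text>>", [dataset]))      # -> "Summarize: {text}"
validate_prompt("Summarize: <<missing>>", [dataset])          # AssertionError: referenced data not present
```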
2 changes: 1 addition & 1 deletion aixplain/factories/model_factory.py
@@ -90,7 +90,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model:
             headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
         else:
             headers = {"Authorization": f"Token {cls.api_key}", "Content-Type": "application/json"}
-        logging.info(f"Start service for GET Metric - {url} - {headers}")
+        logging.info(f"Start service for GET Model - {url} - {headers}")
         r = _request_with_retry("get", url, headers=headers)
         resp = r.json()
         # set api key
3 changes: 1 addition & 2 deletions aixplain/modules/__init__.py
@@ -28,7 +28,6 @@
 from .metric import Metric
 from .model import Model
 from .pipeline import Pipeline
-from .finetune import Finetune
+from .finetune import Finetune, FinetuneCost
 from .benchmark import Benchmark
 from .benchmark_job import BenchmarkJob
-from .finetune_cost import FinetuneCost
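With the cost module folded into the `finetune` package, both names keep resolving from `aixplain.modules`, so downstream imports like the one below should be unaffected (a minimal sanity check, assuming the package is installed):

```python
from aixplain.modules import Finetune, FinetuneCost  # both re-exported via aixplain.modules.finetune
```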
@@ -20,17 +20,20 @@
 Description:
     FineTune Class
 """
-from typing import List, Text
+from typing import List, Text, Optional
 import logging
-from aixplain.utils.file_utils import _request_with_retry
 import json
 from urllib.parse import urljoin
-from aixplain.utils import config
+from aixplain.modules.finetune.cost import FinetuneCost
+from aixplain.modules.finetune.hyperparameters import Hyperparameters
+from aixplain.modules.finetune.peft import Peft
 from aixplain.factories.model_factory import ModelFactory
 from aixplain.modules.asset import Asset
 from aixplain.modules.dataset import Dataset
 from aixplain.modules.model import Model
-from aixplain.modules.finetune_cost import FinetuneCost
+
+from aixplain.utils import config
+from aixplain.utils.file_utils import _request_with_retry
 
 
 class Finetune(Asset):
Expand All @@ -47,6 +50,9 @@ class Finetune(Asset):
version (Text): Version of the FineTune.
train_percentage (float): Percentage of training samples.
dev_percentage (float): Percentage of development samples.
prompt (Text): Fine-tuning prompt.
hyperparameters (Hyperparameters): Hyperparameters for fine-tuning.
peft (Peft): PEFT (Parameter-Efficient Fine-Tuning) configuration.
additional_info (dict): Additional information to be saved with the FineTune.
backend_url (str): URL of the backend.
api_key (str): The TEAM API key used for authentication.
@@ -58,12 +64,15 @@ def __init__(
         dataset_list: List[Dataset],
         model: Model,
         cost: FinetuneCost,
-        id: Text = "",
-        description: Text = "",
-        supplier: Text = "aiXplain",
-        version: Text = "1.0",
-        train_percentage: float = 100,
-        dev_percentage: float = 0,
+        id: Optional[Text] = "",
+        description: Optional[Text] = "",
+        supplier: Optional[Text] = "aiXplain",
+        version: Optional[Text] = "1.0",
+        train_percentage: Optional[float] = 100,
+        dev_percentage: Optional[float] = 0,
+        prompt: Optional[Text] = None,
+        hyperparameters: Optional[Hyperparameters] = None,
+        peft: Optional[Peft] = None,
         **additional_info,
     ) -> None:
         """Create a FineTune with the necessary information.
@@ -79,6 +88,9 @@ def __init__(
             version (Text, optional): Version of the FineTune. Defaults to "1.0".
             train_percentage (float, optional): Percentage of training samples. Defaults to 100.
             dev_percentage (float, optional): Percentage of development samples. Defaults to 0.
+            prompt (Text, optional): Fine-tuning prompt. Defaults to None.
+            hyperparameters (Hyperparameters, optional): Hyperparameters for fine-tuning. Defaults to None.
+            peft (Peft, optional): PEFT (Parameter-Efficient Fine-Tuning) configuration. Defaults to None.
             **additional_info: Additional information to be saved with the FineTune.
         """
         super().__init__(id, name, description, supplier, version)
@@ -87,6 +99,9 @@ def __init__(
         self.cost = cost
         self.train_percentage = train_percentage
         self.dev_percentage = dev_percentage
+        self.prompt = prompt
+        self.hyperparameters = hyperparameters
+        self.peft = peft
         self.additional_info = additional_info
         self.backend_url = config.BACKEND_URL
         self.api_key = config.TEAM_API_KEY
@@ -102,22 +117,28 @@ def start(self) -> Model:
         try:
             url = urljoin(self.backend_url, f"sdk/finetune")
             headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"}
-            payload = json.dumps(
-                {
-                    "name": self.name,
-                    "datasets": [
-                        {
-                            "datasetId": dataset.id,
-                            "trainSamplesPercentage": self.train_percentage,
-                            "devSamplesPercentage": self.dev_percentage,
-                        }
-                        for dataset in self.dataset_list
-                    ],
-                    "sourceModelId": self.model.id,
-                }
-            )
+            payload = {
+                "name": self.name,
+                "datasets": [
+                    {
+                        "datasetId": dataset.id,
+                        "trainSamplesPercentage": self.train_percentage,
+                        "devSamplesPercentage": self.dev_percentage,
+                    }
+                    for dataset in self.dataset_list
+                ],
+                "sourceModelId": self.model.id,
+            }
+            parameters = {}
+            if self.prompt is not None:
+                parameters["prompt"] = self.prompt
+            if self.hyperparameters is not None:
+                parameters["hyperparameters"] = self.hyperparameters.to_dict()
+            if self.peft is not None:
+                parameters["peft"] = self.peft.to_dict()
+            payload["parameters"] = parameters
             logging.info(f"Start service for POST Start FineTune - {url} - {headers} - {json.dumps(payload)}")
-            r = _request_with_retry("post", url, headers=headers, data=payload)
+            r = _request_with_retry("post", url, headers=headers, json=payload)
             resp = r.json()
             logging.info(f"Response for POST Start FineTune - Name: {self.name} / Status {resp}")
             return ModelFactory().get(resp["id"])
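For context, once a `Finetune` is created the flow would look roughly like this. `check_finetune_status` does exist on `Model` (see model.py below), but the status value shown in the comment is an assumption, not taken from this diff:

```python
finetune_model = finetune.start()                # submits the job, returns the in-training Model
status = finetune_model.check_finetune_status()  # poll until training finishes (e.g. an "onboarded"-style status)
```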
File renamed without changes.
17 changes: 17 additions & 0 deletions aixplain/modules/finetune/hyperparameters.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+from dataclasses_json import dataclass_json
+
+
+@dataclass_json
+@dataclass
+class Hyperparameters(object):
+    epochs: int = 4
+    train_batch_size: int = 4
+    eval_batch_size: int = 4
+    learning_rate: float = 2e-5
+    warmup_steps: int = 500
+    generation_max_length: int = 225
+    tokenizer_batch_size: int = 256
+    gradient_checkpointing: bool = False
+    gradient_accumulation_steps: int = 1
+    max_seq_length: int = 4096
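Because `Hyperparameters` is decorated with `@dataclass_json`, the `to_dict()` call used when building the request payload comes from the `dataclasses-json` package rather than being hand-written. A minimal sketch:

```python
hp = Hyperparameters(epochs=2, learning_rate=1e-4)
print(hp.to_dict())  # {'epochs': 2, 'train_batch_size': 4, ..., 'learning_rate': 0.0001, ...}
```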
10 changes: 10 additions & 0 deletions aixplain/modules/finetune/peft.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+from dataclasses_json import dataclass_json
+
+
+@dataclass_json
+@dataclass
+class Peft(object):
+    peft_lora_r: int = 8
+    peft_lora_alpha: int = 32
+    peft_lora_dropout: float = 0.05
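Same pattern for the LoRA settings; all three fields have defaults, so a bare `Peft()` is also valid:

```python
peft = Peft(peft_lora_r=16, peft_lora_dropout=0.1)
print(peft.to_dict())  # {'peft_lora_r': 16, 'peft_lora_alpha': 32, 'peft_lora_dropout': 0.1}
```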
14 changes: 14 additions & 0 deletions aixplain/modules/model.py
@@ -279,3 +279,17 @@ def check_finetune_status(self):
                 message = f"Status {status_code} - {message}"
             error_message = f"Check FineTune status Model: Error {message}"
             logging.exception(error_message)
+
+    def delete(self) -> None:
+        """Delete Model service"""
+        try:
+            url = urljoin(self.backend_url, f"sdk/models/{self.id}")
+            headers = {"Authorization": f"Token {self.api_key}", "Content-Type": "application/json"}
+            logging.info(f"Start service for DELETE Model - {url} - {headers}")
+            r = _request_with_retry("delete", url, headers=headers)
+            if r.status_code != 200:
+                raise Exception()
+        except Exception:
+            message = "Model Deletion Error: Make sure the model exists and you are the owner."
+            logging.error(message)
+            raise Exception(f"{message}")
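Usage of the new `delete` is straightforward; note that any failure, including a non-200 response, surfaces as a generic `Exception` carrying the ownership hint. The model ID below is illustrative:

```python
model = ModelFactory.get("MODEL_ID")  # hypothetical ID
model.delete()  # raises Exception unless the model exists and you own it
```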
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -50,7 +50,8 @@ dependencies = [
     "validators>=0.20.0",
     "filetype>=1.2.0",
     "click>=8.1.7",
-    "PyYAML>=6.0.1"
+    "PyYAML>=6.0.1",
+    "dataclasses-json==0.6.1"
 ]
 
 [project.urls]
12 changes: 12 additions & 0 deletions tests/functional/finetune/data/finetune_test_end2end.json
@@ -0,0 +1,12 @@
+[
+    {
+        "model_name": "Chat GPT 3.5",
+        "dataset_name": "Test text generation dataset",
+        "inference_data": "Hello!"
+    },
+    {
+        "model_name": "GPT2",
+        "dataset_name": "Test text generation dataset",
+        "inference_data": "Hello!"
+    }
+]
8 changes: 1 addition & 7 deletions tests/functional/finetune/data/finetune_test_list_data.json
@@ -1,11 +1,5 @@
 [
     {
-        "function": "translation",
-        "source_language": {"language": "en", "dialect": ""},
-        "target_language": {"language": "fr", "dialect": ""}
-    },
-    {
-        "function": "speech-recognition",
-        "source_language": {"language": "en", "dialect": ""}
+        "function": "text-generation"
     }
 ]