Skip to content

Commit

Permalink
Merge pull request #133 from IBM/develop
Browse files Browse the repository at this point in the history
Release 2023.08.21
  • Loading branch information
Tomas2D authored Aug 21, 2023
2 parents 30af41e + 6847416 commit f7680ea
Show file tree
Hide file tree
Showing 15 changed files with 441 additions and 15 deletions.
7 changes: 7 additions & 0 deletions examples/user/prompt_tuning/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ def get_creds():
)
time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
Expand Down
9 changes: 9 additions & 0 deletions examples/user/prompt_tuning/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def get_creds():
tune_get_result,
)

time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
Expand Down
76 changes: 76 additions & 0 deletions examples/user/prompt_tuning/tune_manager_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
from pathlib import Path

from dotenv import load_dotenv

from genai.credentials import Credentials
from genai.schemas.tunes_params import (
CreateTuneHyperParams,
CreateTuneParams,
DownloadAssetsParams,
TunesListParams,
)
from genai.services import FileManager, TuneManager

load_dotenv()
API_KEY = os.getenv("GENAI_KEY", None)
ENDPOINT = os.getenv("GENAI_API", None)

creds = Credentials(api_key=API_KEY, api_endpoint=ENDPOINT)


print("\n=============== UPLOAD LOCAL FILE FOR TUNING ================")
file_path = str(Path(__file__).parent / "file_to_tune.jsonl")
print(" {} \n".format(file_path))

file_uploaded = FileManager.upload_file(credentials=creds, file_path=file_path, purpose="tune")
print("File uploaded: \n", file_uploaded)


print("\n======================== CREATE TUNE ========================")
file_ids = [file_uploaded.id]
hyperparams = CreateTuneHyperParams(verbalizer='classify { "red", "yellow" } Input: {{input}} Output:')

params = CreateTuneParams(
name="Tune Manager Classification",
model_id="google/flan-t5-xl",
method_id="mpt",
task_id="classification",
training_file_ids=file_ids,
parameters=hyperparams,
)

tune_created = TuneManager.create_tune(credentials=creds, params=params)
print("\nNew tune: \n", tune_created)


print("\n======================= GET TUNE BY ID ======================")
t = tune_created.id
tune_get = TuneManager.get_tune(credentials=creds, tune_id=t)
print("\nGet tune result: \n", tune_get)

print("\n======================== LIST TUNES =========================")
list_params = TunesListParams(limit=5, offset=0)
tune_list = TuneManager.list_tunes(credentials=creds, params=list_params)
for t in tune_list.results:
print("\nTune ID:", t.id, ", Tune Name:", t.name)


print("\n================ LIST EXISTING TUNE METHODS ================")
tune_methods = TuneManager.get_tune_methods(credentials=creds)
print("\n Available Tune Methods:\n")
for t in tune_methods.results:
print("Method ID:", t.id, ", Method name:", t.name)


print("\n=================== DOWNLOAD TUNE ASSETS ===================")
# content can be: logs or enconder. The download will only be available when the tune is complete.
tune = tune_list.results[0].id
assets_params = DownloadAssetsParams(id=tune, content="encoder")
tune_assets = TuneManager.download_tune_assets(credentials=creds, params=assets_params)
print("\n Tune assets:", tune_assets)


print("\n======================== DELETE TUNE ========================")
tune_delete = TuneManager.delete_tune(credentials=creds, tune_id=tune)
print("\nDelete tune response: \n", tune_delete)
44 changes: 44 additions & 0 deletions examples/user/prompt_tuning/upload_file_create_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
from pathlib import Path

from dotenv import load_dotenv

from genai.credentials import Credentials
from genai.schemas.tunes_params import CreateTuneParams
from genai.services import FileManager, TuneManager

load_dotenv()
API_KEY = os.getenv("GENAI_KEY", None)
ENDPOINT = os.getenv("GENAI_API", None)

creds = Credentials(api_key=API_KEY, api_endpoint=ENDPOINT)

# # UPLOAD FILE
# file_path = "<path-to-file>"
print("======================== UPLOAD LOCAL FILE ========================")
file_path = str(Path(__file__).parent / "file_to_tune.jsonl")
print(" {} \n".format(file_path))

upload_file = FileManager.upload_file(credentials=creds, file_path=file_path, purpose="tune")
print("File uploaded: \n", upload_file)

print("\n===================== GET UPLOADED FILE ID ========================")
file_list = FileManager.list_files(credentials=creds)
for f in file_list.results:
if f.file_name == "file_to_tune.jsonl":
file_id = f.id
break
print("Uploaded file has id = {}\n".format(file_id))

print("\n================= CREATE TUNE FOR GENERATION TASK ==================")

tunes_params = CreateTuneParams(
name="flan-t5-xl",
model_id="google/flan-t5-xl",
method_id="pt",
task_id="generation",
training_file_ids=[file_id],
)

tune_create = TuneManager.create_tune(credentials=creds, params=tunes_params)
print("Created tune has the id: \n ", tune_create.id)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ readme = "README.md"
dependencies = [
"urllib3<2", # https://github.com/psf/requests/issues/6432
"requests>=2.31.0",
"pydantic<=1.10.10",
"pydantic>=1.10.10,<2",
"python-dotenv>=1.0.0",
"aiohttp>=3.8.4",
"pyyaml>=0.2.5",
Expand Down
9 changes: 9 additions & 0 deletions src/genai/exceptions/genai_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,14 @@ def __init__(self, error: Union[Exception, Response]) -> None:
else:
self.error = error
self.error_message = str(error)
if "TOU_NOT_ACCEPTED" in self.error_message:
split_message = self.error_message.split("Terms of use not accepted")
self.error_message = "".join(
[
split_message[0],
"Terms of use not accepted. Please accept the terms of use in a browser.",
split_message[1],
]
)
logger.error(self.error_message)
super().__init__(self.error_message)
4 changes: 3 additions & 1 deletion src/genai/extensions/localserver/local_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,7 @@ async def _route_generate(self, generate_request: GenerateRequestBody):
for input in generate_request.inputs
]
created_at = datetime.datetime.now().isoformat()
response = GenerateResponse(model_id=generate_request.model_id, created_at=created_at, results=results)
response = GenerateResponse(
id=str(uuid.uuid4()), model_id=generate_request.model_id, created_at=created_at, results=results
)
return response
7 changes: 7 additions & 0 deletions src/genai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from genai.schemas.tunes_params import (
CreateTuneHyperParams,
CreateTuneParams,
DownloadAssetsParams,
TunesListParams,
)
from genai.services import AsyncResponseGenerator, ServiceInterface
Expand Down Expand Up @@ -367,6 +368,12 @@ def delete(self):
raise GenAiException(ValueError("Tuned model not found. Currently method supports only tuned models."))
TuneManager.delete_tune(service=self.service, tune_id=self.model)

def download(self):
enconder_params = DownloadAssetsParams(id=self.model, content="encoder")
logs_params = DownloadAssetsParams(id=self.model, content="logs")
TuneManager.download_tune_assets(service=self.service, params=enconder_params)
TuneManager.download_tune_assets(service=self.service, params=logs_params)

@staticmethod
def models(credentials: Credentials = None, service: ServiceInterface = None) -> list[ModelCard]:
"""Get a list of models
Expand Down
30 changes: 29 additions & 1 deletion src/genai/routers/tunes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from genai.exceptions import GenAiException
from genai.schemas.tunes_params import CreateTuneParams, TunesListParams
from genai.schemas.tunes_params import (
CreateTuneParams,
DownloadAssetsParams,
TunesListParams,
)
from genai.services.request_handler import RequestHandler
from genai.utils.request_utils import sanitize_params

Expand Down Expand Up @@ -73,3 +77,27 @@ def delete_tune(self, tune_id: str):
return RequestHandler.delete(endpoint, key=self.key, parameters=tune_id)
except Exception as e:
raise GenAiException(e)

def get_tune_methods(self):
"""Get list of tune methods.
Returns:
Any: json with info about the available tune methods.
"""
try:
endpoint = self.service_url + "/tune_methods"
return RequestHandler.get(endpoint, key=self.key)
except Exception as e:
raise GenAiException(e)

def download_tune_assets(self, params: DownloadAssetsParams):
"""Download tune asset.
Returns:
Any: json with info about the downloaded tune asset.
"""
try:
endpoint = self.service_url + TunesRouter.TUNES + "/" + params.id + "/content/" + params.content
return RequestHandler.get(endpoint, key=self.key)
except Exception as e:
raise GenAiException(e)
2 changes: 2 additions & 0 deletions src/genai/schemas/descriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class TunesAPIDescriptions:
STATUS = "Filters the items to be returned based on their status. Possible values are: INITIALIZING, NOT_STARTED, PENDING, HALTED, RUNNING, QUEUED, COMPLETED, FAILED."
INIT_METHOD = "Initialization method to be used. Possible values are RANDOM or TEXT. Defaults to RANDOM. Used only if the method_id is 'pt' = Prompt Tuning."
INIT_TEXT = "Initialization text to be used. This is only applicable if init_method == TEXT. Used only if the method_id is 'pt' = Prompt Tuning."
ID = "The ID of the tune."
CONTENT = "The name of the asset. Available options are encoder and logs."


class FilesAPIDescriptions:
Expand Down
57 changes: 57 additions & 0 deletions src/genai/schemas/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class GenerateResult(GenAiResponseModel):


class GenerateResponse(GenAiResponseModel):
id: str
model_id: str
created_at: datetime
results: List[GenerateResult]
Expand Down Expand Up @@ -228,6 +229,62 @@ class TuneGetResponse(GenAiResponseModel):
results: Optional[TuneInfoResult]


class TuneMethodsInfo(GenAiResponseModel):
id: str
name: str


class TuneMethodsGetResponse(GenAiResponseModel):
results: Optional[List[TuneMethodsInfo]]


class FileFormatResult(GenAiResponseModel):
id: int
name: str


class FileInfoResult(GenAiResponseModel):
id: str
bytes: str
file_name: str
purpose: str
storage_provider_location: Optional[str]
created_at: datetime
file_formats: List[FileFormatResult]


class TuneParameters(GenAiResponseModel):
accumulate_steps: Optional[int]
batch_size: Optional[int]
learning_rate: Optional[float]
max_input_tokens: Optional[int]
max_output_tokens: Optional[int]
num_epochs: Optional[int]
num_virtual_tokens: Optional[int]
verbalizer: Optional[str]


class TuneInfoResult(GenAiResponseModel):
id: str
name: str
model_id: str
model_name: str
method_id: Optional[str]
method_name: Optional[str]
status: str
task_id: str
task_name: Optional[str]
parameters: Optional[TuneParameters]
created_at: datetime
preferred: Optional[bool]
datapoints: Optional[dict]
validation_files: Optional[list]
training_files: Optional[list]
evaluation_files: Optional[list]
status_message: Optional[str]
started_at: Optional[datetime]


class ModelCard(GenAiResponseModel):
id: Optional[str]
name: Optional[str]
Expand Down
11 changes: 11 additions & 0 deletions src/genai/schemas/tunes_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,14 @@ class Config:
training_file_ids: list[str] = Field(None, description=tx.TRAINING_FILE_IDS)
validation_file_ids: Optional[list[str]] = Field(None, description=tx.VALIDATION_FILE_IDS)
parameters: Optional[CreateTuneHyperParams] = Field(None, description=tx.PARAMETERS)


class DownloadAssetsParams(BaseModel):
"""Class to hold the parameters for downloading tune assets."""

class Config:
anystr_strip_whitespace = True
# extra: Extra.forbid

id: str = Field(None, description=tx.ID)
content: str = Field(None, description=tx.CONTENT)
Loading

0 comments on commit f7680ea

Please sign in to comment.