Update tunes api #102

Merged
merged 80 commits into from
Aug 9, 2023
Changes from 78 commits
dd8b702
Merge pull request #1 from IBM/develop
onkarbhardwaj Jun 22, 2023
1ae83f9
Merge branch 'IBM:develop' into develop
onkarbhardwaj Jul 3, 2023
b353c3e
init prompt tuning: add examples -files and tunes
mirianfsilva Jul 3, 2023
eade6f7
add schemas
mirianfsilva Jul 3, 2023
cce1c01
init file services
mirianfsilva Jul 3, 2023
9d85108
init tune services
mirianfsilva Jul 3, 2023
edda015
feat: tune manager
moneill0 Jul 3, 2023
7878761
fix: fixed tunes router
moneill0 Jul 3, 2023
e36af10
fix: fix merge conflicts
moneill0 Jul 3, 2023
e79844d
update imports from file example, update routers
mirianfsilva Jul 3, 2023
afd0571
add tests for files
mirianfsilva Jul 4, 2023
1dcec66
update tunes doc-strings
mirianfsilva Jul 4, 2023
72eb7ce
request handler changes - add files to httpx POST methods
mirianfsilva Jul 4, 2023
c80824f
update files docstrings, improve tests
mirianfsilva Jul 4, 2023
1cc0390
fix delete file tests
mirianfsilva Jul 4, 2023
46a46f2
add tunes tests
mirianfsilva Jul 4, 2023
ddae508
tunes examples updates
mirianfsilva Jul 4, 2023
9e0ddd4
add new file-tune example
mirianfsilva Jul 4, 2023
860a869
add descriptions and adjust imports
onkarbhardwaj Jul 5, 2023
1056493
feat: get tune methods function
moneill0 Jul 6, 2023
ba58545
fix: fixed get methods endpoint url
moneill0 Jul 6, 2023
b0f4b76
fix pre-commit
mirianfsilva Jul 6, 2023
ca5dc1a
tune() and status() methods for model tuning
Jul 6, 2023
1e697ee
documentation for status() method
Jul 6, 2023
9e7869d
feat: download tune assets function
moneill0 Jul 6, 2023
74fb5d7
unit tests
Jul 6, 2023
294c1d4
test: tune manager pytests
moneill0 Jul 6, 2023
26b5373
improve response helper for tests
mirianfsilva Jul 7, 2023
2500b87
fix: fixed get tunes schema check
moneill0 Jul 7, 2023
268eb76
Merge pull request #2 from onkarbhardwaj/tune_manager_pytest
mirianfsilva Jul 7, 2023
e40ea67
changing type in schemas for some fields
Jul 7, 2023
3cc9864
Merge pull request #3 from onkarbhardwaj/feature/prompt-tuning-model
onkarbhardwaj Jul 7, 2023
57d31f4
Prompt tuning for classification
Jul 7, 2023
1a00e22
delete tune and change paramter name
Jul 7, 2023
f1b0b49
Merge pull request #4 from onkarbhardwaj/feature/prompt-tuning-model
onkarbhardwaj Jul 7, 2023
f057609
accept ServiceInterface as credentials - FileManager
mirianfsilva Jul 7, 2023
fdad69e
accept ServiceInterface as credentials - TuneManager
mirianfsilva Jul 10, 2023
5f8661e
Merge pull request #5 from onkarbhardwaj/init-service
mirianfsilva Jul 10, 2023
99a1e15
update examples
mirianfsilva Jul 10, 2023
e879b06
FileManager and TuneManager can take both creds and servie interface
Jul 10, 2023
ec7f9c6
align examples with changed method signatures
Jul 10, 2023
6faf133
update schemas, docs strings, error messages
mirianfsilva Jul 10, 2023
20701a5
Examples for classification and summarization
Jul 10, 2023
91670ae
Revert "accept ServiceInterface as credentials - FileManager, TuneMan…
onkarbhardwaj Jul 10, 2023
0597c25
Merge pull request #7 from onkarbhardwaj/revert-5-init-service
onkarbhardwaj Jul 10, 2023
fad838c
Merge pull request #8 from onkarbhardwaj/feature/prompt-tuning-model
onkarbhardwaj Jul 10, 2023
a9c2bb6
remove examples that were superseded
Jul 10, 2023
4095d5c
Merge pull request #9 from onkarbhardwaj/feature/prompt-tuning-model
onkarbhardwaj Jul 10, 2023
8ad6014
update branch
mirianfsilva Jul 10, 2023
8444216
update docstrings
mirianfsilva Jul 10, 2023
4c86a1a
wrote examples for download assets
moneill0 Jul 10, 2023
4087147
update schema
mirianfsilva Jul 10, 2023
ca84efd
Merge pull request #6 from onkarbhardwaj/review
mirianfsilva Jul 10, 2023
8d9250f
update tests
mirianfsilva Jul 10, 2023
d56564c
Merge branch 'IBM:develop' into develop
onkarbhardwaj Jul 10, 2023
08678f7
Merge pull request #11 from onkarbhardwaj/develop
onkarbhardwaj Jul 10, 2023
435ec16
Moved sanitize param to utils folder to remove circular import
Jul 10, 2023
f81b360
Merge pull request #12 from onkarbhardwaj/feature/prompt-tuning-vdemers
onkarbhardwaj Jul 10, 2023
4ae4c39
docstrings and function signatures
Jul 10, 2023
43421db
Merge pull request #13 from onkarbhardwaj/fixes/docstrings-signatures
onkarbhardwaj Jul 10, 2023
53d79ac
Merge branch 'feature/prompt-tuning' into update_tunes_api
moneill0 Jul 11, 2023
1be3fc3
test: pytest for get tune methods
moneill0 Jul 11, 2023
076dd05
Merge branch 'IBM:develop' into develop
onkarbhardwaj Jul 11, 2023
f8251b9
fix: fixed download tune assets, add path argument
moneill0 Jul 12, 2023
5c541d1
style: cleaned code and added helper functions
moneill0 Jul 13, 2023
8625fb7
Merge branch 'main' into update_tunes_api
moneill0 Jul 13, 2023
1451c48
update branch with develop
mirianfsilva Jul 13, 2023
ea1190e
fix: removed files added to commit by mistake
moneill0 Jul 14, 2023
58710e7
refactor: handling of default argument and pytest
moneill0 Jul 17, 2023
d40cb0c
fix: fixed download assets pytest
moneill0 Jul 17, 2023
8562c35
fix: fixed request handler
moneill0 Jul 18, 2023
d754bf3
Merge branch 'develop' into update_tunes_api
onkarbhardwaj Jul 27, 2023
0aaa357
Update responses.py
onkarbhardwaj Aug 2, 2023
7c45048
Updates: examples, doc strings and tune manager
mirianfsilva Aug 3, 2023
921b771
Updates: examples, doc strings and tune manager
mirianfsilva Aug 3, 2023
2d5da28
update tests cases, improve tune downloads and tune delete
mirianfsilva Aug 3, 2023
ac04a9a
merge develop
mirianfsilva Aug 3, 2023
243a028
add download option in model class
mirianfsilva Aug 4, 2023
75ed79c
Update schemas
mirianfsilva Aug 7, 2023
0809717
run pre-commit
mirianfsilva Aug 7, 2023
7 changes: 7 additions & 0 deletions examples/user/prompt_tuning/classification.py
@@ -145,6 +145,13 @@ def get_creds():
)
time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
    tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
9 changes: 9 additions & 0 deletions examples/user/prompt_tuning/summarization.py
@@ -135,6 +135,15 @@ def get_creds():
tune_get_result,
)

time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
    tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
76 changes: 76 additions & 0 deletions examples/user/prompt_tuning/tune_manager_functions.py
@@ -0,0 +1,76 @@
import os
from pathlib import Path

from dotenv import load_dotenv

from genai.credentials import Credentials
from genai.schemas.tunes_params import (
    CreateTuneHyperParams,
    CreateTuneParams,
    DownloadAssetsParams,
    TunesListParams,
)
from genai.services import FileManager, TuneManager

load_dotenv()
API_KEY = os.getenv("GENAI_KEY", None)
ENDPOINT = os.getenv("GENAI_API", None)

creds = Credentials(api_key=API_KEY, api_endpoint=ENDPOINT)


print("\n=============== UPLOAD LOCAL FILE FOR TUNING ================")
file_path = str(Path(__file__).parent / "file_to_tune.jsonl")
print(" {} \n".format(file_path))

file_uploaded = FileManager.upload_file(credentials=creds, file_path=file_path, purpose="tune")
print("File uploaded: \n", file_uploaded)


print("\n======================== CREATE TUNE ========================")
file_ids = [file_uploaded.id]
hyperparams = CreateTuneHyperParams(verbalizer='classify { "red", "yellow" } Input: {{input}} Output:')

params = CreateTuneParams(
    name="Tune Manager Classification",
    model_id="google/flan-t5-xl",
    method_id="mpt",
    task_id="classification",
    training_file_ids=file_ids,
    parameters=hyperparams,
)

tune_created = TuneManager.create_tune(credentials=creds, params=params)
print("\nNew tune: \n", tune_created)


print("\n======================= GET TUNE BY ID ======================")
t = tune_created.id
tune_get = TuneManager.get_tune(credentials=creds, tune_id=t)
print("\nGet tune result: \n", tune_get)

print("\n======================== LIST TUNES =========================")
list_params = TunesListParams(limit=5, offset=0)
tune_list = TuneManager.list_tunes(credentials=creds, params=list_params)
for t in tune_list.results:
    print("\nTune ID:", t.id, ", Tune Name:", t.name)


print("\n================ LIST EXISTING TUNE METHODS ================")
tune_methods = TuneManager.get_tune_methods(credentials=creds)
print("\n Available Tune Methods:\n")
for t in tune_methods.results:
    print("Method ID:", t.id, ", Method name:", t.name)


print("\n=================== DOWNLOAD TUNE ASSETS ===================")
# content can be "logs" or "encoder". The download is only available once the tune is complete.
tune = tune_list.results[0].id
assets_params = DownloadAssetsParams(id=tune, content="encoder")
tune_assets = TuneManager.download_tune_assets(credentials=creds, params=assets_params)
print("\n Tune assets:", tune_assets)


print("\n======================== DELETE TUNE ========================")
tune_delete = TuneManager.delete_tune(credentials=creds, tune_id=tune)
print("\nDelete tune response: \n", tune_delete)
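The comment in the download step says assets are only available once the tune is complete. A minimal sketch of waiting for that state before calling `download_tune_assets` could look like the following. `wait_for_status` and its `get_status` callable are hypothetical helpers, not part of the SDK. The status names come from the `STATUS` description in this change, and treating only `COMPLETED` and `FAILED` as terminal is an assumption.

```python
import time
from typing import Callable

# Assumption: only these two statuses are treated as terminal.
TERMINAL_STATUSES = {"COMPLETED", "FAILED"}


def wait_for_status(
    get_status: Callable[[], str],
    poll_seconds: float = 5.0,
    max_polls: int = 120,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status() until it returns a terminal status.

    Hypothetical helper: get_status is any zero-argument callable that
    returns the tune's current status string.
    """
    for _ in range(max_polls):
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        # Not finished yet: wait before asking again.
        sleep(poll_seconds)
    raise TimeoutError(f"tune did not finish after {max_polls} polls")
```

With the SDK, `get_status` could plausibly be `lambda: TuneManager.get_tune(credentials=creds, tune_id=tune).results.status`, assuming the response exposes `results.status` as declared in `TuneGetResponse`.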
44 changes: 44 additions & 0 deletions examples/user/prompt_tuning/upload_file_create_tune.py
@@ -0,0 +1,44 @@
import os
from pathlib import Path

from dotenv import load_dotenv

from genai.credentials import Credentials
from genai.schemas.tunes_params import CreateTuneParams
from genai.services import FileManager, TuneManager

load_dotenv()
API_KEY = os.getenv("GENAI_KEY", None)
ENDPOINT = os.getenv("GENAI_API", None)

creds = Credentials(api_key=API_KEY, api_endpoint=ENDPOINT)

# # UPLOAD FILE
# file_path = "<path-to-file>"
print("======================== UPLOAD LOCAL FILE ========================")
file_path = str(Path(__file__).parent / "file_to_tune.jsonl")
print(" {} \n".format(file_path))

upload_file = FileManager.upload_file(credentials=creds, file_path=file_path, purpose="tune")
print("File uploaded: \n", upload_file)

print("\n===================== GET UPLOADED FILE ID ========================")
file_list = FileManager.list_files(credentials=creds)
for f in file_list.results:
    if f.file_name == "file_to_tune.jsonl":
        file_id = f.id
        break
print("Uploaded file has id = {}\n".format(file_id))

print("\n================= CREATE TUNE FOR GENERATION TASK ==================")

tunes_params = CreateTuneParams(
    name="flan-t5-xl",
    model_id="google/flan-t5-xl",
    method_id="pt",
    task_id="generation",
    training_file_ids=[file_id],
)

tune_create = TuneManager.create_tune(credentials=creds, params=tunes_params)
print("Created tune has the id: \n ", tune_create.id)
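The lookup loop in this example leaves `file_id` unset when no listed file matches the name, so the later `print` would raise a `NameError`. A small sketch of a lookup that makes the not-found case explicit follows. `find_file_id` and the `FileEntry` stand-in are hypothetical; only the `id` and `file_name` fields are assumed, matching `FileInfoResult`.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class FileEntry:
    """Stand-in for one entry of file_list.results (hypothetical)."""

    id: str
    file_name: str


def find_file_id(files: Iterable[FileEntry], file_name: str) -> Optional[str]:
    """Return the id of the first file named file_name, or None if absent."""
    return next((f.id for f in files if f.file_name == file_name), None)
```

The caller can then check for `None` before creating the tune. Since `tune_manager_functions.py` reads the ID directly from the upload result (`file_uploaded.id`), the listing step may not be needed at all.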
7 changes: 7 additions & 0 deletions src/genai/model.py
@@ -24,6 +24,7 @@
from genai.schemas.tunes_params import (
    CreateTuneHyperParams,
    CreateTuneParams,
    DownloadAssetsParams,
    TunesListParams,
)
from genai.services import AsyncResponseGenerator, ServiceInterface
@@ -367,6 +368,12 @@ def delete(self):
            raise GenAiException(ValueError("Tuned model not found. Currently method supports only tuned models."))
        TuneManager.delete_tune(service=self.service, tune_id=self.model)

    def download(self):
        """Download the encoder and logs assets of this tuned model."""
        encoder_params = DownloadAssetsParams(id=self.model, content="encoder")
        logs_params = DownloadAssetsParams(id=self.model, content="logs")
        TuneManager.download_tune_assets(service=self.service, params=encoder_params)
        TuneManager.download_tune_assets(service=self.service, params=logs_params)

    @staticmethod
    def models(credentials: Credentials = None, service: ServiceInterface = None) -> list[ModelCard]:
        """Get a list of models
30 changes: 29 additions & 1 deletion src/genai/routers/tunes.py
@@ -1,5 +1,9 @@
from genai.exceptions import GenAiException
from genai.schemas.tunes_params import CreateTuneParams, TunesListParams
from genai.schemas.tunes_params import (
    CreateTuneParams,
    DownloadAssetsParams,
    TunesListParams,
)
from genai.services.request_handler import RequestHandler
from genai.utils.request_utils import sanitize_params

@@ -73,3 +77,27 @@ def delete_tune(self, tune_id: str):
            return RequestHandler.delete(endpoint, key=self.key, parameters=tune_id)
        except Exception as e:
            raise GenAiException(e)

    def get_tune_methods(self):
        """Get list of tune methods.

        Returns:
            Any: json with info about the available tune methods.
        """
        try:
            endpoint = self.service_url + "/tune_methods"
            return RequestHandler.get(endpoint, key=self.key)
        except Exception as e:
            raise GenAiException(e)

    def download_tune_assets(self, params: DownloadAssetsParams):
        """Download tune asset.

        Args:
            params (DownloadAssetsParams): ID of the tune and name of the asset ("encoder" or "logs").

        Returns:
            Any: json with info about the downloaded tune asset.
        """
        try:
            endpoint = self.service_url + TunesRouter.TUNES + "/" + params.id + "/content/" + params.content
            return RequestHandler.get(endpoint, key=self.key)
        except Exception as e:
            raise GenAiException(e)
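`download_tune_assets` builds its URL by concatenating the service URL, the tunes path, the tune ID, and the asset name. The sketch below shows the same construction as a standalone function, with an up-front check of the asset name. `build_asset_endpoint` is a hypothetical helper, `"/tunes"` is an assumed value for `TunesRouter.TUNES`, and the URLs used with it are placeholders.

```python
# Asset names listed in the CONTENT description of this change.
ALLOWED_CONTENT = ("encoder", "logs")


def build_asset_endpoint(
    service_url: str,
    tune_id: str,
    content: str,
    tunes_path: str = "/tunes",  # assumption: stands in for TunesRouter.TUNES
) -> str:
    """Build the download URL for one tune asset."""
    if content not in ALLOWED_CONTENT:
        raise ValueError(f"content must be one of {ALLOWED_CONTENT}, got {content!r}")
    # Unlike the plain concatenation in the router, this drops a trailing
    # slash on the service URL to avoid a double slash.
    return f"{service_url.rstrip('/')}{tunes_path}/{tune_id}/content/{content}"
```

Rejecting unknown asset names before the request is a design choice of this sketch; the router as written passes `params.content` through unchecked.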
2 changes: 2 additions & 0 deletions src/genai/schemas/descriptions.py
@@ -69,6 +69,8 @@ class TunesAPIDescriptions:
    STATUS = "Filters the items to be returned based on their status. Possible values are: INITIALIZING, NOT_STARTED, PENDING, HALTED, RUNNING, QUEUED, COMPLETED, FAILED."
    INIT_METHOD = "Initialization method to be used. Possible values are RANDOM or TEXT. Defaults to RANDOM. Used only if the method_id is 'pt' = Prompt Tuning."
    INIT_TEXT = "Initialization text to be used. This is only applicable if init_method == TEXT. Used only if the method_id is 'pt' = Prompt Tuning."
    ID = "The ID of the tune."
    CONTENT = "The name of the asset. Available options are encoder and logs."


class FilesAPIDescriptions:
56 changes: 56 additions & 0 deletions src/genai/schemas/responses.py
@@ -229,6 +229,62 @@ class TuneGetResponse(GenAiResponseModel):
    results: Optional[TuneInfoResult]


class TuneMethodsInfo(GenAiResponseModel):
    id: str
    name: str


class TuneMethodsGetResponse(GenAiResponseModel):
    results: Optional[List[TuneMethodsInfo]]


class FileFormatResult(GenAiResponseModel):
    id: int
    name: str


class FileInfoResult(GenAiResponseModel):
    id: str
    bytes: str
    file_name: str
    purpose: str
    storage_provider_location: Optional[str]
    created_at: datetime
    file_formats: List[FileFormatResult]


class TuneParameters(GenAiResponseModel):
    accumulate_steps: Optional[int]
    batch_size: Optional[int]
    learning_rate: Optional[float]
    max_input_tokens: Optional[int]
    max_output_tokens: Optional[int]
    num_epochs: Optional[int]
    num_virtual_tokens: Optional[int]
    verbalizer: Optional[str]


class TuneInfoResult(GenAiResponseModel):
    id: str
    name: str
    model_id: str
    model_name: str
    method_id: Optional[str]
    method_name: Optional[str]
    status: str
    task_id: str
    task_name: Optional[str]
    parameters: Optional[TuneParameters]
    created_at: datetime
    preferred: Optional[bool]
    datapoints: Optional[dict]
    validation_files: Optional[list]
    training_files: Optional[list]
    evaluation_files: Optional[list]
    status_message: Optional[str]
    started_at: Optional[datetime]


class ModelCard(GenAiResponseModel):
    id: Optional[str]
    name: Optional[str]
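`TuneMethodsGetResponse.results` is declared `Optional`, so code that iterates over it may want to handle `None`. A minimal sketch of that handling, using plain dataclasses as stand-ins for the response models, is shown below. `MethodInfo`, `MethodsResponse`, `parse_methods`, and `method_ids` are hypothetical names, and any payload passed to them here is illustrative rather than real API output.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MethodInfo:
    """Stand-in for TuneMethodsInfo (hypothetical)."""

    id: str
    name: str


@dataclass
class MethodsResponse:
    """Stand-in for TuneMethodsGetResponse (hypothetical)."""

    results: Optional[List[MethodInfo]] = None


def parse_methods(payload: dict) -> MethodsResponse:
    """Turn a decoded JSON payload into the stand-in response type."""
    raw = payload.get("results")
    if raw is None:
        return MethodsResponse(results=None)
    return MethodsResponse(results=[MethodInfo(id=m["id"], name=m["name"]) for m in raw])


def method_ids(response: MethodsResponse) -> List[str]:
    """List method IDs, treating a missing results field as empty."""
    return [m.id for m in (response.results or [])]
```

The `response.results or []` idiom keeps a loop such as the one in `tune_manager_functions.py` from failing when the field is absent.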
11 changes: 11 additions & 0 deletions src/genai/schemas/tunes_params.py
@@ -52,3 +52,14 @@ class Config:
    training_file_ids: list[str] = Field(None, description=tx.TRAINING_FILE_IDS)
    validation_file_ids: Optional[list[str]] = Field(None, description=tx.VALIDATION_FILE_IDS)
    parameters: Optional[CreateTuneHyperParams] = Field(None, description=tx.PARAMETERS)


class DownloadAssetsParams(BaseModel):
    """Class to hold the parameters for downloading tune assets."""

    class Config:
        anystr_strip_whitespace = True
        # extra: Extra.forbid

    id: str = Field(None, description=tx.ID)
    content: str = Field(None, description=tx.CONTENT)
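Options in a pydantic v1-style `Config` class are read as class attributes, so they need an assignment (`=`). A bare annotation written with `:` is recorded in `__annotations__` but sets no attribute, and should therefore have no effect on the model. The plain-Python sketch below illustrates the difference; the class names and `option_is_set` are hypothetical and no pydantic import is involved.

```python
class AnnotatedOnly:
    # A bare annotation: stored in __annotations__, but no attribute is created.
    anystr_strip_whitespace: True


class Assigned:
    # An assignment: creates a real class attribute.
    anystr_strip_whitespace = True


def option_is_set(config_cls: type, name: str) -> bool:
    """Return True if config_cls defines name as a class attribute."""
    return hasattr(config_cls, name)
```

Calling `option_is_set` on each class shows that only the assigned form exposes the option as an attribute.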