Skip to content

Commit

Permalink
add download option in model class
Browse files Browse the repository at this point in the history
Signed-off-by: Mírian Silva <mirianfrsilva@ibm.com>
  • Loading branch information
mirianfsilva committed Aug 4, 2023
1 parent ac04a9a commit 243a028
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/user/prompt_tuning/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ def get_creds():
)
time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
Expand Down
9 changes: 9 additions & 0 deletions examples/user/prompt_tuning/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@ def get_creds():
tune_get_result,
)

time.sleep(5)

print("~~~~~~~ Downloading tuned model assets ~~~~~")
to_download_assets = input("Download tuned model assets? (y/N):\n")
if to_download_assets == "y":
tuned_model.download()

time.sleep(5)

print("~~~~~~~ Deleting a tuned model ~~~~~")
to_delete = input("Delete this model? (y/N):\n")
if to_delete == "y":
Expand Down
7 changes: 7 additions & 0 deletions src/genai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from genai.schemas.tunes_params import (
CreateTuneHyperParams,
CreateTuneParams,
DownloadAssetsParams,
TunesListParams,
)
from genai.services import AsyncResponseGenerator, ServiceInterface
Expand Down Expand Up @@ -367,6 +368,12 @@ def delete(self):
raise GenAiException(ValueError("Tuned model not found. Currently method supports only tuned models."))
TuneManager.delete_tune(service=self.service, tune_id=self.model)

def download(self):
enconder_params = DownloadAssetsParams(id=self.model, content="encoder")
logs_params = DownloadAssetsParams(id=self.model, content="logs")
TuneManager.download_tune_assets(service=self.service, params=enconder_params)
TuneManager.download_tune_assets(service=self.service, params=logs_params)

@staticmethod
def models(credentials: Credentials = None, service: ServiceInterface = None) -> list[ModelCard]:
"""Get a list of models
Expand Down

0 comments on commit 243a028

Please sign in to comment.