feat: add fine-tuning with deployments #357

Merged

6 commits merged on May 22, 2024

Changes from 3 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
playground/

# Downloaded data for examples
examples/.data
examples/**/.data/

# Pickle files
*.pkl
3 changes: 3 additions & 0 deletions examples/deployment/__init__.py
@@ -0,0 +1,3 @@
"""
Deployment
"""
163 changes: 163 additions & 0 deletions examples/deployment/deployment.py
@@ -0,0 +1,163 @@
"""
Fine-tune and deploy a custom model

Use custom training data to tune a model for text generation.

Note:
    This example has been written to enable an end user to quickly try fine-tuning. To obtain better
    performance, experiment with the number of observations and the tuning hyperparameters.
"""

import time
from pathlib import Path
from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
    DecodingMethod,
    DeploymentStatus,
    FilePurpose,
    TextGenerationParameters,
    TuneParameters,
    TuneStatus,
)

load_dotenv()
num_training_samples = 50
num_validation_samples = 20
data_root = Path(__file__).parent.resolve() / ".data"
training_file = data_root / "fpb_train.jsonl"
validation_file = data_root / "fpb_validation.jsonl"


def heading(text: str) -> str:
    """Helper function for centering text."""
    return "\n" + f" {text} ".center(80, "=") + "\n"


def create_dataset():
    Path(data_root).mkdir(parents=True, exist_ok=True)
    if training_file.exists():
        print("Dataset is already prepared")
        return

    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    data = load_dataset("locuslab/TOFU")
    df = pd.DataFrame(data["train"])
    df.rename(columns={"question": "input", "answer": "output"}, inplace=True)
    df["output"] = df["output"].astype(str)
    # Use the first rows for training and the last rows for validation
    train_jsonl = df.iloc[:num_training_samples].to_json(orient="records", lines=True, force_ascii=True)
    validation_jsonl = df.iloc[-num_validation_samples:].to_json(orient="records", lines=True, force_ascii=True)
    with open(training_file, "w") as fout:
        fout.write(train_jsonl)
    with open(validation_file, "w") as fout:
        fout.write(validation_jsonl)


def upload_files(client: Client, update=True):
    files_info = client.file.list(search=training_file.name).results
    files_info += client.file.list(search=validation_file.name).results

    filenames_to_id = {f.file_name: f.id for f in files_info}
    for filepath in [training_file, validation_file]:
        filename = filepath.name
        if filename in filenames_to_id and update:
            print(f"File already present: Overwriting {filename}")
            client.file.delete(filenames_to_id[filename])
            response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
            filenames_to_id[filename] = response.result.id
        if filename not in filenames_to_id:
            print(f"File not present: Uploading {filename}")
            response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
            filenames_to_id[filename] = response.result.id
    return filenames_to_id[training_file.name], filenames_to_id[validation_file.name]


client = Client(credentials=Credentials.from_env())

print(heading("Creating dataset"))
create_dataset()

print(heading("Uploading files"))
training_file_id, validation_file_id = upload_files(client, update=True)

hyperparams = TuneParameters(
    num_epochs=4,
    verbalizer="### Input: {{input}} ### Response: {{output}}",
    batch_size=4,
    learning_rate=0.4,
    # Advanced parameters are not defined in the schema
    # but can be passed to the API
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
    num_train_epochs=4,
)
print(heading("Tuning model"))

tune_result = client.tune.create(
    model_id="meta-llama/llama-3-8b-instruct",
    name="generation-fine-tune-example",
    tuning_type="fine_tuning",
    task_id="generation",
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id],  # TODO: Broken at the moment - this causes tune to fail
).result

# Poll until the tune reaches a terminal state
while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    new_tune_result = client.tune.retrieve(tune_result.id).result
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    tune_result = new_tune_result
    time.sleep(10)

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    exit(1)

print("Model tuned successfully")

print(heading("Deploying fine-tuned model"))

deployment = client.deployment.create(tune_id=tune_result.id).result

while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    deployment = client.deployment.retrieve(id=deployment.id).result
    print(f"Waiting for deployment to finish, current status: {deployment.status}")
    time.sleep(10)

if deployment.status in [DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print("Model deployment failed or expired")
    exit(1)

print("Model deployed successfully")

print(heading("Generate text with fine-tuned model"))
prompt = "What are some books you would recommend to read?"
print("Prompt: ", prompt)
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE)
gen_response = next(client.text.generation.create(model_id=tune_result.id, inputs=[prompt], parameters=gen_params))
print("Answer: ", gen_response.results[0].generated_text)

print(heading("Get list of deployed models"))
deployment_list = client.deployment.list()
for deployed_model in deployment_list.results:  # separate name so `deployment` still refers to the one created above
    pprint(deployed_model.model_dump())

print(heading("Retrieving information about deployment"))
deployment_info = client.deployment.retrieve(id=deployment.id)
pprint(deployment_info.model_dump())

print(heading("Deleting deployment and tuned model"))
client.deployment.delete(id=deployment.id)
client.tune.delete(id=tune_result.id)
print("Deleted")
@@ -27,7 +27,8 @@ def __call__(self, inputs: Documents) -> Embeddings:
    for response in self._client.text.embedding.create(
        model_id=self._model_id, inputs=inputs, parameters=self._parameters
    ):
        embeddings.extend(response.results)
        embedding_list = [result.embedding for result in response.results]
        embeddings.extend(embedding_list)

    return embeddings

1 change: 1 addition & 0 deletions scripts/docs_examples_generator/config.py
@@ -21,6 +21,7 @@ class GeneratorConfig(BaseModel):
        "examples.text",
        "examples.model",
        "examples.tune",
        "examples.deployment",
        "examples.prompt",
        "examples.system_prompt",
        "examples.file",
66 changes: 58 additions & 8 deletions scripts/types_generator/schema_aliases.yaml
@@ -33,6 +33,28 @@ alias:
- _FileCreate_result
- _FileIdRetrieve_result
- _FileIdPatch_result
FileMetadata:
- _BetaEvaluationIdRetrieve_result_file_metadata
- _FileIdPatch_result_metadata
- _BetaEvaluationExperimentIdRetrieve_result_file_metadata
- _BetaEvaluationCreate_result_file_metadata
- _FileCreate_result_metadata
- _BetaEvaluationRetrieve_results_file_metadata
- _FileIdRetrieve_result_metadata
- _FileRetrieve_results_metadata
- _BetaEvaluationExperimentCreate_result_file_metadata
- _BetaEvaluationExperimentRetrieve_results_file_metadata
FileMetadataStats:
- _FileCreate_result_metadata_stats
- _BetaEvaluationCreate_result_file_metadata_stats
- _BetaEvaluationExperimentIdRetrieve_result_file_metadata_stats
- _FileIdRetrieve_result_metadata_stats
- _FileIdPatch_result_metadata_stats
- _BetaEvaluationExperimentRetrieve_results_file_metadata_stats
- _BetaEvaluationIdRetrieve_result_file_metadata_stats
- _FileRetrieve_results_metadata_stats
- _BetaEvaluationRetrieve_results_file_metadata_stats
- _BetaEvaluationExperimentCreate_result_file_metadata_stats
FileListSortBy:
- _FileRetrieveRequestParamsSortBy

@@ -472,12 +494,36 @@ alias:
- _TuneFromFileCreate_result_evaluation_files
- _TuneIdPatch_result_training_files
- _TuneRetrieve_results_evaluation_files
TuneResultDatapointLoss:
TuneResultContent:
- _TuneRetrieve_results_contents
- _TuneIdPatch_result_contents
- _TuneIdRetrieve_result_contents
- _TuneFromFileCreate_result_contents
- _TuneCreate_result_contents
TuneResultDatapoint:
- _TuneIdRetrieve_result_datapoints
- _TuneRetrieve_results_datapoints
- _TuneCreate_result_datapoints
- _TuneFromFileCreate_result_datapoints
- _TuneIdPatch_result_datapoints
TuneResultDatapointLossData:
- _TuneRetrieve_results_datapoints_loss_data
- _TuneCreate_result_datapoints_loss_data
- _TuneFromFileCreate_result_datapoints_loss_data
- _TuneIdRetrieve_result_datapoints_loss_data
- _TuneIdPatch_result_datapoints_loss_data
TuneResultDatapointValidationLoss:
- _TuneIdPatch_result_datapoints_validation_loss
- _TuneRetrieve_results_datapoints_validation_loss
- _TuneFromFileCreate_result_datapoints_validation_loss
- _TuneCreate_result_datapoints_validation_loss
- _TuneIdRetrieve_result_datapoints_validation_loss
TuneResultDatapointValidationLossData:
- _TuneRetrieve_results_datapoints_validation_loss_data
- _TuneCreate_result_datapoints_validation_loss_data
- _TuneIdPatch_result_datapoints_validation_loss_data
- _TuneIdRetrieve_result_datapoints_validation_loss_data
- _TuneFromFileCreate_result_datapoints_validation_loss_data
TuneListSortBy:
- _TuneRetrieveRequestParamsSortBy
PromptListSortBy:
@@ -490,13 +536,8 @@ alias:
- _TuneCreate_result
TuneParameters:
- _TuneCreateRequest_parameters
- _TunePreflightCreateRequest_parameters
# - _TuningTypeRetrieve_results_schema_properties_parameters
TuneResultDatapoints:
- _TuneFromFileCreate_result_datapoints
- _TuneIdPatch_result_datapoints
- _TuneCreate_result_datapoints
- _TuneRetrieve_results_datapoints
- _TuneIdRetrieve_result_datapoints
TextGenerationResult:
- _TextGenerationComparisonCreate_results_result_results
- _TextGenerationCreate_results
@@ -600,6 +641,15 @@ alias:
- _TextTokenizationCreateRequest_parameters
TextTokenizationReturnOptions:
- _TextTokenizationCreateRequest_parameters_return_options
DeploymentResult:
- _DeploymentCreate201_result
- _DeploymentRetrieve_results
- _DeploymentIdRetrieve_result
DeploymentStatus:
- _DeploymentRetrieve_results_status
- _DeploymentIdRetrieve_result_status
- _DeploymentCreate201_result_status

# SYSTEM PROMPTS -------------------------------------------------------------
SystemPrompt:
- _SystemPromptCreate_result
3 changes: 3 additions & 0 deletions src/genai/client.py
@@ -12,6 +12,7 @@
    BaseServiceServices,
)
from genai.credentials import Credentials
from genai.deployment import DeploymentService as _DeploymentService
from genai.file import FileService as _FileService
from genai.folder import FolderService as _FolderService
from genai.model import ModelService as _ModelService
@@ -41,6 +42,7 @@ class BaseServices(BaseServiceServices):
    TagService: type[_TagService] = _TagService
    FolderService: type[_FolderService] = _FolderService
    TaskService: type[_TaskService] = _TaskService
    DeploymentService: type[_DeploymentService] = _DeploymentService


class BaseConfig(BaseServiceConfig):
@@ -144,3 +146,4 @@ def __init__(
        self.tag = services.TagService(api_client=api_client)
        self.folder = services.FolderService(api_client=api_client)
        self.task = services.TaskService(api_client=api_client)
        self.deployment = services.DeploymentService(api_client=api_client)
3 changes: 3 additions & 0 deletions src/genai/deployment/__init__.py
@@ -0,0 +1,3 @@
"""Functionalities related to deployment"""

from genai.deployment.deployment_service import *
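
For quick reference, the new `client.deployment` service added in this PR covers the create, poll, list, and delete lifecycle exercised by the example above. A minimal sketch, assuming an already completed tune; `existing_tune_id` is a placeholder, not something defined in this diff:

import time

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import DeploymentStatus

client = Client(credentials=Credentials.from_env())

# Deploy a previously tuned model (placeholder tune ID)
deployment = client.deployment.create(tune_id="existing_tune_id").result

# Poll until the deployment reaches a terminal state
while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    deployment = client.deployment.retrieve(id=deployment.id).result
    time.sleep(10)

# Inspect all deployments, then clean up the one created above
for item in client.deployment.list().results:
    print(item.id, item.status)
client.deployment.delete(id=deployment.id)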