Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fine-tuning with deployments #357

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
playground/

# Downloaded data for examples
examples/.data
examples/**/.data/

# Pickle files
*.pkl
Expand Down
3 changes: 3 additions & 0 deletions examples/deployment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Deployment
"""
36 changes: 36 additions & 0 deletions examples/deployment/deployed_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Get all deployed models
"""

from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials

load_dotenv()


def heading(text: str) -> str:
    """Render *text* as an 80-column '=' banner framed by blank lines."""
    banner = " {} ".format(text).center(80, "=")
    return "\n{}\n".format(banner)


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

print(heading("Get list of deployed models"))
# List every deployment visible to this API key and dump each one as a dict.
deployment_list = client.deployment.list()
for deployment in deployment_list.results:
    pprint(deployment.model_dump())

# Nothing deployed: exit with a non-zero status so CI/example runners notice.
if len(deployment_list.results) < 1:
    print("No deployed models found.")
    exit(1)

print(heading("Retrieve information about first deployment"))
# Fetch the full details of a single deployment by its id.
deployment_info = client.deployment.retrieve(id=deployment_list.results[0].id)
pprint(deployment_info.model_dump())
152 changes: 152 additions & 0 deletions examples/deployment/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
Fine tune and deploy custom model

Use custom training data to tune a model for text generation.

Note:
This example has been written to enable an end-user to quickly try fine-tuning. In order to obtain better
performance, a user would need to experiment with the number of observations and tuning hyperparameters
"""

import time
from pathlib import Path

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
DecodingMethod,
DeploymentStatus,
FilePurpose,
TextGenerationParameters,
TuneParameters,
TuneStatus,
)

load_dotenv()  # pull GENAI_KEY / GENAI_API from a local .env file
# Number of rows taken from the downloaded dataset for tuning / validation.
num_training_samples = 50
num_validation_samples = 20
# Local cache directory for the generated JSONL files (gitignored as **/.data/).
data_root = Path(__file__).parent.resolve() / ".data"
# NOTE(review): the files are named fpb_* but create_dataset() downloads the
# locuslab/TOFU dataset — the names look stale; confirm intent.
training_file = data_root / "fpb_train.jsonl"
validation_file = data_root / "fpb_validation.jsonl"


def heading(text: str) -> str:
    """Return *text* padded with '=' into an 80-character section header."""
    padded = f" {text} "
    return "".join(["\n", padded.center(80, "="), "\n"])


def create_dataset():
    """Download the locuslab/TOFU dataset and write train/validation JSONL files.

    Writes ``num_training_samples`` rows to ``training_file`` and the last
    ``num_validation_samples`` rows to ``validation_file``, both in the
    records-per-line JSONL format the tuning API expects
    ({"input": ..., "output": ...}).

    Raises:
        ImportError: if the optional ``datasets``/``pandas`` packages are missing.
    """
    data_root.mkdir(parents=True, exist_ok=True)
    # Bug fix: previously only the training file was checked, so a deleted
    # validation file would never be regenerated.
    if training_file.exists() and validation_file.exists():
        print("Dataset is already prepared")
        return

    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    data = load_dataset("locuslab/TOFU")
    df = pd.DataFrame(data["train"])
    # Rename columns to the input/output schema expected by the tune verbalizer.
    df.rename(columns={"question": "input", "answer": "output"}, inplace=True)
    df["output"] = df["output"].astype(str)
    train_jsonl = df.iloc[:num_training_samples].to_json(orient="records", lines=True, force_ascii=True)
    validation_jsonl = df.iloc[-num_validation_samples:].to_json(orient="records", lines=True, force_ascii=True)
    training_file.write_text(train_jsonl)
    validation_file.write_text(validation_jsonl)


def upload_files(client: Client, update=True):
    """Upload the training and validation files to the service.

    Args:
        client: API client used for file listing/upload/delete.
        update: when True, an already-uploaded file with the same name is
            deleted and re-uploaded; when False the existing remote copy is kept.

    Returns:
        Tuple of (training_file_id, validation_file_id).
    """
    files_info = client.file.list(search=training_file.name).results
    files_info += client.file.list(search=validation_file.name).results

    filenames_to_id = {f.file_name: f.id for f in files_info}
    for filepath in [training_file, validation_file]:
        filename = filepath.name
        if filename in filenames_to_id:
            if not update:
                continue  # keep the existing remote copy and its id
            # Remove the stale remote copy before re-uploading.
            print(f"File already present: Overwriting {filename}")
            client.file.delete(filenames_to_id[filename])
        else:
            print(f"File not present: Uploading {filename}")
        # Single upload path (previously duplicated in both branches); the
        # original f-strings also lacked the {filename} placeholder.
        response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
        filenames_to_id[filename] = response.result.id
    return filenames_to_id[training_file.name], filenames_to_id[validation_file.name]


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

print(heading("Creating dataset"))
create_dataset()

print(heading("Uploading files"))
# Overwrite any stale remote copies and keep the returned ids for the tune call.
training_file_id, validation_file_id = upload_files(client, update=True)

hyperparams = TuneParameters(
    num_epochs=4,
    verbalizer="### Input: {{input}} ### Response: {{output}}",
    batch_size=4,
    learning_rate=0.4,
    # Advanced parameters are not defined in the schema
    # but can be passed to the API
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
    num_train_epochs=4,
)
print(heading("Tuning model"))

tune_result = client.tune.create(
    model_id="meta-llama/llama-3-8b-instruct",
    name="generation-fine-tune-example",
    tuning_type="fine_tuning",
    task_id="generation",
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id], # TODO: Broken at the moment - this causes tune to fail
).result

# Poll every 10s until the tune reaches a terminal state.
# NOTE(review): the status printed below is the one fetched on the *previous*
# iteration (tune_result is only swapped after the print), so the log can lag
# one poll behind the actual state.
while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    new_tune_result = client.tune.retrieve(tune_result.id).result
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    tune_result = new_tune_result
    time.sleep(10)

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    exit(1)

print(heading("Deploying fine-tuned model"))

# A deployment makes the tuned model callable for inference.
deployment = client.deployment.create(tune_id=tune_result.id).result

# Poll every 10s until the deployment is READY, or has failed/expired.
while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    deployment = client.deployment.retrieve(id=deployment.id).result
    print(f"Waiting for deployment to finish, current status: {deployment.status}")
    time.sleep(10)

if deployment.status in [DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print(f"Model deployment failed or expired, status: {deployment.status}")
    exit(1)

print(heading("Generate text with fine-tuned model"))
# NOTE(review): 'reccomend' typo is part of the runtime prompt string and is
# deliberately left as-is here (changing it would alter the request payload).
prompt = "What are some books you would reccomend to read?"
print("Prompt: ", prompt)
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE)
# NOTE(review): gen_params is built but never passed to generation.create —
# presumably this should be parameters=gen_params; confirm and wire it through.
# Inference targets the tuned model via its tune id; next(...) takes the single
# response produced for the single input.
gen_response = next(client.text.generation.create(model_id=tune_result.id, inputs=[prompt]))

print("Answer: ", gen_response.results[0].generated_text)

print(heading("Deleting deployment and tuned model"))
# Clean up so the example does not leave billable resources behind.
client.deployment.delete(id=deployment.id)
client.tune.delete(id=tune_result.id)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __call__(self, inputs: Documents) -> Embeddings:
for response in self._client.text.embedding.create(
model_id=self._model_id, inputs=inputs, parameters=self._parameters
):
embeddings.extend(response.results)
embedding_list = [result.embedding for result in response.results]
embeddings.extend(embedding_list)
Tomas2D marked this conversation as resolved.
Show resolved Hide resolved

return embeddings

Expand Down
1 change: 1 addition & 0 deletions scripts/docs_examples_generator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class GeneratorConfig(BaseModel):
"examples.text",
"examples.model",
"examples.tune",
"examples.deployment",
"examples.prompt",
"examples.system_prompt",
"examples.file",
Expand Down
Loading