Skip to content

Commit

Permalink
feat: add fine-tuning with deployments (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
David-Kristek committed May 22, 2024
1 parent e8ecae9 commit 99b77c0
Show file tree
Hide file tree
Showing 29 changed files with 3,739 additions and 639 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
playground/

# Downloaded data for examples
examples/.data
examples/**/.data/

# Pickle files
*.pkl
Expand Down
3 changes: 3 additions & 0 deletions examples/deployment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Deployment
"""
36 changes: 36 additions & 0 deletions examples/deployment/deployed_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Get all deployed models
"""

from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials

load_dotenv()


def heading(text: str) -> str:
    """Format *text* as an 80-column banner padded with '=' and framed by newlines."""
    banner = " {} ".format(text).center(80, "=")
    return "\n{}\n".format(banner)


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

# List every deployment visible to this API key and dump each record.
print(heading("Get list of deployed models"))
deployment_list = client.deployment.list()
for deployment in deployment_list.results:
    pprint(deployment.model_dump())

# Nothing deployed: the retrieve call below would fail, so bail out early.
if len(deployment_list.results) < 1:
    print("No deployed models found.")
    exit(1)

# Fetch the full record for the first deployment in the list.
print(heading("Retrieve information about first deployment"))
deployment_info = client.deployment.retrieve(id=deployment_list.results[0].id)
pprint(deployment_info.model_dump())
152 changes: 152 additions & 0 deletions examples/deployment/deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
Fine tune and deploy custom model
Use custom training data to tune a model for text generation.
Note:
This example has been written to enable an end-user to quickly try fine-tuning. In order to obtain better
performance, a user would need to experiment with the number of observations and tuning hyperparameters
"""

import time
from pathlib import Path

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
DecodingMethod,
DeploymentStatus,
FilePurpose,
TextGenerationParameters,
TuneParameters,
TuneStatus,
)

load_dotenv()
num_training_samples = 50
num_validation_samples = 20
data_root = Path(__file__).parent.resolve() / ".data"
training_file = data_root / "fpb_train.jsonl"
validation_file = data_root / "fpb_validation.jsonl"


def heading(text: str) -> str:
    """Center *text* inside an 80-character '=' rule, with a blank line on each side."""
    return "".join(["\n", f" {text} ".center(80, "="), "\n"])


def create_dataset():
    """Download the TOFU dataset and write train/validation JSONL files under `data_root`.

    Skips the download when the dataset has already been prepared. Requires the
    optional `datasets` and `pandas` packages; raises ImportError (after printing
    a hint) when they are missing.
    """
    data_root.mkdir(parents=True, exist_ok=True)
    # Check BOTH files: an interrupted earlier run may have written only the
    # training file, in which case the validation file still needs creating.
    if training_file.exists() and validation_file.exists():
        print("Dataset is already prepared")
        return

    try:
        import pandas as pd
        from datasets import load_dataset
    except ImportError:
        print("Please install datasets and pandas for downloading the dataset.")
        raise

    data = load_dataset("locuslab/TOFU")
    df = pd.DataFrame(data["train"])
    # Normalize column names to the {"input", "output"} schema the tuning API expects.
    df.rename(columns={"question": "input", "answer": "output"}, inplace=True)
    df["output"] = df["output"].astype(str)
    # First N rows for training, last M rows for validation (disjoint as long as
    # num_training_samples + num_validation_samples <= len(df)).
    train_jsonl = df.iloc[:num_training_samples].to_json(orient="records", lines=True, force_ascii=True)
    validation_jsonl = df.iloc[-num_validation_samples:].to_json(orient="records", lines=True, force_ascii=True)
    training_file.write_text(train_jsonl)
    validation_file.write_text(validation_jsonl)


def upload_files(client: Client, update=True):
    """Upload the training and validation files to the service.

    When a file with the same name is already uploaded: replace it if *update*
    is true, otherwise reuse the existing upload.

    Returns:
        Tuple of (training_file_id, validation_file_id).
    """
    files_info = client.file.list(search=training_file.name).results
    files_info += client.file.list(search=validation_file.name).results

    filenames_to_id = {f.file_name: f.id for f in files_info}
    for filepath in [training_file, validation_file]:
        filename = filepath.name
        if filename in filenames_to_id:
            if not update:
                continue  # keep the already-uploaded copy
            # Interpolate the actual filename (the original f-string had no placeholder).
            print(f"File already present: Overwriting {filename}")
            client.file.delete(filenames_to_id[filename])
        else:
            print(f"File not present: Uploading {filename}")
        response = client.file.create(file_path=filepath, purpose=FilePurpose.TUNE)
        filenames_to_id[filename] = response.result.id
    return filenames_to_id[training_file.name], filenames_to_id[validation_file.name]


# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint> (optional) DEFAULT_API = "https://bam-api.res.ibm.com"
client = Client(credentials=Credentials.from_env())

print(heading("Creating dataset"))
create_dataset()

print(heading("Uploading files"))
training_file_id, validation_file_id = upload_files(client, update=True)

hyperparams = TuneParameters(
    num_epochs=4,
    verbalizer="### Input: {{input}} ### Response: {{output}}",
    batch_size=4,
    learning_rate=0.4,
    # Advanced parameters are not defined in the schema
    # but can be passed to the API
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=4,
    num_train_epochs=4,
)
print(heading("Tuning model"))

tune_result = client.tune.create(
    model_id="meta-llama/llama-3-8b-instruct",
    name="generation-fine-tune-example",
    tuning_type="fine_tuning",
    task_id="generation",
    parameters=hyperparams,
    training_file_ids=[training_file_id],
    # validation_file_ids=[validation_file_id], # TODO: Broken at the moment - this causes tune to fail
).result

# Poll until the tune reaches a terminal state (completed, failed, or halted).
while tune_result.status not in [TuneStatus.FAILED, TuneStatus.HALTED, TuneStatus.COMPLETED]:
    tune_result = client.tune.retrieve(tune_result.id).result
    # Print the freshly retrieved status (the original printed the stale,
    # pre-fetch status, so the log lagged one poll behind reality).
    print(f"Waiting for tune to finish, current status: {tune_result.status}")
    time.sleep(10)

if tune_result.status in [TuneStatus.FAILED, TuneStatus.HALTED]:
    print("Model tuning failed or halted")
    exit(1)

print(heading("Deploying fine-tuned model"))

deployment = client.deployment.create(tune_id=tune_result.id).result

# Poll until the deployment is ready (or has failed/expired).
while deployment.status not in [DeploymentStatus.READY, DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    deployment = client.deployment.retrieve(id=deployment.id).result
    print(f"Waiting for deployment to finish, current status: {deployment.status}")
    time.sleep(10)

if deployment.status in [DeploymentStatus.FAILED, DeploymentStatus.EXPIRED]:
    print(f"Model deployment failed or expired, status: {deployment.status}")
    exit(1)

print(heading("Generate text with fine-tuned model"))
# Fixed "reccomend" -> "recommend" typo in the prompt.
prompt = "What are some books you would recommend to read?"
print("Prompt: ", prompt)
gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE)
# Deployed tuned models are addressed by their tune id as model_id.
gen_response = next(client.text.generation.create(model_id=tune_result.id, inputs=[prompt]))

print("Answer: ", gen_response.results[0].generated_text)

# Clean up so the example leaves no billable resources behind.
print(heading("Deleting deployment and tuned model"))
client.deployment.delete(id=deployment.id)
client.tune.delete(id=tune_result.id)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __call__(self, inputs: Documents) -> Embeddings:
for response in self._client.text.embedding.create(
model_id=self._model_id, inputs=inputs, parameters=self._parameters
):
embeddings.extend(response.results)
embedding_list = [result.embedding for result in response.results]
embeddings.extend(embedding_list)

return embeddings

Expand Down
1 change: 1 addition & 0 deletions scripts/docs_examples_generator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class GeneratorConfig(BaseModel):
"examples.text",
"examples.model",
"examples.tune",
"examples.deployment",
"examples.prompt",
"examples.system_prompt",
"examples.file",
Expand Down

0 comments on commit 99b77c0

Please sign in to comment.