From 9488c3d907be045dc479cdad1b6c25d5a9f4e4c8 Mon Sep 17 00:00:00 2001
From: Oleg Kostromin
Date: Wed, 23 Aug 2023 20:54:58 +0200
Subject: [PATCH] gpt_tuning

---
 README.md                    |  65 ++++++++++++-
 skllm/models/gpt/__init__.py |   2 +
 skllm/models/gpt/gpt.py      | 174 +++++++++++++++++++++++++++++++++++
 skllm/openai/tuning.py       |  67 ++++++++++++++
 4 files changed, 306 insertions(+), 2 deletions(-)
 create mode 100644 skllm/models/gpt/gpt.py
 create mode 100644 skllm/openai/tuning.py

diff --git a/README.md b/README.md
index cb0c38f..2ce61c1 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,7 @@ Note: as the model is not being re-trained, but uses the training data during in
 
 ### Dynamic Few-Shot Text Classification
 
-*To use this feature, you need to install `annoy` library:*
+_To use this feature, you need to install the `annoy` library:_
 
 ```bash
 pip install scikit-llm[annoy]
@@ -210,7 +210,7 @@ pip install scikit-llm[annoy]
 
 `DynamicFewShotGPTClassifier` dynamically selects N samples per class to include in the prompt. This allows the few-shot classifier to scale to datasets that are too large for the standard context window of LLMs.
 
-*How does it work?*
+_How does it work?_
 
 During fitting, the whole dataset is partitioned by class, vectorized, and stored.
 
@@ -280,6 +280,67 @@ clf.fit(X_train, y_train_encoded)
 yh = clf.predict(X_test)
 ```
 
+### LLM Fine-Tuning
+
+At the moment, the following tuning scenarios are supported:
+
+- **Text classification**: the model is fine-tuned to predict a single label per sample. The following estimators are supported:
+  - `skllm.models.palm.PaLMClassifier`
+  - `skllm.models.gpt.GPTClassifier`
+- **Text-to-text**: the model is fine-tuned on arbitrary text input-output pairs. The following estimators are supported:
+  - `skllm.models.palm.PaLM`
+  - `skllm.models.gpt.GPT`
+
+Example 1: Fine-tuning a PaLM model for text classification
+
+```python
+from skllm.models.palm import PaLMClassifier
+clf = PaLMClassifier(n_update_steps=100)
+clf.fit(X_train, y_train)  # y_train is a list of labels
+labels = clf.predict(X_test)
+```
+
+Example 2: Fine-tuning a PaLM model for text-to-text tasks
+
+```python
+from skllm.models.palm import PaLM
+clf = PaLM(n_update_steps=100)
+clf.fit(X_train, y_train)  # y_train is any desired output text
+preds = clf.predict(X_test)
+```
+
+_Note:_ Tuning PaLM models requires a Vertex AI account. Please refer to our [official guide on Medium](https://medium.com/@iryna230520/fine-tune-google-palm-2-with-scikit-llm-d41b0aa673a5) for more details.
+
+Example 3: Fine-tuning a GPT model for text classification
+
+```python
+from skllm.models.gpt import GPTClassifier
+
+clf = GPTClassifier(
+    base_model="gpt-3.5-turbo-0613",
+    n_epochs=None,  # int or None. When None, will be determined automatically by OpenAI
+    default_label="Random",  # optional
+)
+
+clf.fit(X_train, y_train)  # y_train is a list of labels
+labels = clf.predict(X_test)
+```
+
+Example 4: Fine-tuning a GPT model for text-to-text tasks
+
+```python
+from skllm.models.gpt import GPT
+
+clf = GPT(
+    base_model="gpt-3.5-turbo-0613",
+    n_epochs=None,  # int or None. When None, will be determined automatically by OpenAI
+    system_msg="You are a text processing model.",
+)
+
+clf.fit(X_train, y_train)  # y_train is any desired output text
+preds = clf.predict(X_test)
+```
+
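+_Note:_ during `fit`, the training data is converted into a chat-format JSONL file and uploaded to OpenAI; the local file is removed after the upload, and the uploaded copy is deleted once tuning finishes. For reference, a single line of that file (as produced by the internal `_build_clf_example` helper) looks roughly as follows, where `<prompt>` and `<label>` are illustrative placeholders:
+
+```json
+{"messages": [{"role": "system", "content": "You are a text classification model."}, {"role": "user", "content": "<prompt>"}, {"role": "assistant", "content": "{\"label\": \"<label>\"}"}]}
+```
+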
 ### Text Summarization
 
 GPT excels at performing summarization tasks. Therefore, we provide `GPTSummarizer`, which can be used both as a stand-alone estimator and as a preprocessor (in this case, an analogy can be made with a dimensionality reduction preprocessor).

diff --git a/skllm/models/gpt/__init__.py b/skllm/models/gpt/__init__.py
index 4f72fb0..142484d 100644
--- a/skllm/models/gpt/__init__.py
+++ b/skllm/models/gpt/__init__.py
@@ -4,3 +4,5 @@
     ZeroShotGPTClassifier,
     MultiLabelZeroShotGPTClassifier,
 )
+
+from skllm.models.gpt.gpt import GPTClassifier, GPT
\ No newline at end of file
diff --git a/skllm/models/gpt/gpt.py b/skllm/models/gpt/gpt.py
new file mode 100644
index 0000000..cbaab53
--- /dev/null
+++ b/skllm/models/gpt/gpt.py
@@ -0,0 +1,174 @@
+import json
+import uuid
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from skllm.models._base import _BaseZeroShotGPTClassifier
+from skllm.openai.credentials import set_credentials
+from skllm.openai.tuning import await_results, create_tuning_job, delete_file
+from skllm.prompts.builders import build_zero_shot_prompt_slc
+
+
+def _build_clf_example(
+    x: str, y: str, system_msg: str = "You are a text classification model."
+):
+    """Builds a single chat-format training example as a JSON string."""
+    sample = {
+        "messages": [
+            {"role": "system", "content": system_msg},
+            {"role": "user", "content": x},
+            {"role": "assistant", "content": y},
+        ]
+    }
+    return json.dumps(sample)
+
+
+class _Tunable:
+    system_msg = "You are a text classification model."
+
+    def _build_label(self, label: str):
+        return json.dumps({"label": label})
+
+    def _tune(self, X, y):
+        # serialize the training data into a temporary JSONL file
+        file_uuid = str(uuid.uuid4())
+        filename = f"skllm_{file_uuid}.jsonl"
+        with open(filename, "w+") as f:
+            for xi, yi in zip(X, y):
+                f.write(
+                    _build_clf_example(
+                        self._get_prompt(xi), self._build_label(yi), self.system_msg
+                    )
+                )
+                f.write("\n")
+        set_credentials(self._get_openai_key(), self._get_openai_org())
+        job = create_tuning_job(
+            self.base_model,
+            filename,
+            self.n_epochs,
+            self.custom_suffix,
+        )
+        print(f"Created new tuning job. JOB_ID = {job['id']}")
+        job = await_results(job["id"])
+        self.openai_model = job["fine_tuned_model"]
+        delete_file(job["training_file"])
+        print(f"Finished training. Number of trained tokens: {job['trained_tokens']}.")
+
+
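+# Note: _Tunable is used as a mixin. The estimators below provide the
+# base_model / n_epochs / custom_suffix attributes, as well as the _get_prompt
+# and _build_label hooks that _tune relies on.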
+class GPTClassifier(_BaseZeroShotGPTClassifier, _Tunable):
+    """Fine-tunable GPT classifier for single-label classification."""
+
+    supported_models = ["gpt-3.5-turbo-0613"]
+
+    def __init__(
+        self,
+        base_model: str = "gpt-3.5-turbo-0613",
+        default_label: Optional[str] = "Random",
+        openai_key: Optional[str] = None,
+        openai_org: Optional[str] = None,
+        n_epochs: Optional[int] = None,
+        custom_suffix: Optional[str] = "skllm",
+    ):
+        self.base_model = base_model
+        self.n_epochs = n_epochs
+        self.custom_suffix = custom_suffix
+        if base_model not in self.supported_models:
+            raise ValueError(
+                f"Model {base_model} is not supported. Supported models are"
+                f" {self.supported_models}"
+            )
+        super().__init__(
+            openai_model="undefined",  # overwritten with the tuned model id in _tune
+            default_label=default_label,
+            openai_key=openai_key,
+            openai_org=openai_org,
+        )
+
+    def _get_prompt(self, x: str) -> str:
+        return build_zero_shot_prompt_slc(x, repr(self.classes_))
+
+    def fit(
+        self,
+        X: Union[np.ndarray, pd.Series, List[str]],
+        y: Union[np.ndarray, pd.Series, List[str]],
+    ):
+        """Fits the model to the given data.
+
+        Parameters
+        ----------
+        X : Union[np.ndarray, pd.Series, List[str]]
+            training data
+        y : Union[np.ndarray, pd.Series, List[str]]
+            training labels
+
+        Returns
+        -------
+        GPTClassifier
+            self
+        """
+        X = self._to_np(X)
+        y = self._to_np(y)
+        super().fit(X, y)
+        self._tune(X, y)
+        return self
+
+
+# Similarly to PaLM, this is not a classifier but a quick way to reuse the
+# existing code; the class hierarchy will be reworked in the next releases.
+class GPT(_BaseZeroShotGPTClassifier, _Tunable):
+    """Fine-tunable GPT estimator for arbitrary input-output text pairs."""
+
+    supported_models = ["gpt-3.5-turbo-0613"]
+
+    def __init__(
+        self,
+        base_model: str = "gpt-3.5-turbo-0613",
+        openai_key: Optional[str] = None,
+        openai_org: Optional[str] = None,
+        n_epochs: Optional[int] = None,
+        custom_suffix: Optional[str] = "skllm",
+        system_msg: Optional[str] = "You are a text processing model.",
+    ):
+        self.base_model = base_model
+        self.n_epochs = n_epochs
+        self.custom_suffix = custom_suffix
+        self.system_msg = system_msg
+        if base_model not in self.supported_models:
+            raise ValueError(
+                f"Model {base_model} is not supported. Supported models are"
+                f" {self.supported_models}"
+            )
+        super().__init__(
+            openai_model="undefined",  # overwritten with the tuned model id in _tune
+            default_label="Random",  # just for compatibility
+            openai_key=openai_key,
+            openai_org=openai_org,
+        )
+
+    def _get_prompt(self, x: str) -> str:
+        return x
+
+    def _build_label(self, label: str):
+        return label
+
+    def fit(
+        self,
+        X: Union[np.ndarray, pd.Series, List[str]],
+        y: Union[np.ndarray, pd.Series, List[str]],
+    ):
+        """Fits the model to the given data.
+
+        Parameters
+        ----------
+        X : Union[np.ndarray, pd.Series, List[str]]
+            training data (input texts)
+        y : Union[np.ndarray, pd.Series, List[str]]
+            target output texts
+
+        Returns
+        -------
+        GPT
+            self
+        """
+        X = self._to_np(X)
+        y = self._to_np(y)
+        self._tune(X, y)
+        return self
diff --git a/skllm/openai/tuning.py b/skllm/openai/tuning.py
new file mode 100644
index 0000000..3d6cb46
--- /dev/null
+++ b/skllm/openai/tuning.py
@@ -0,0 +1,67 @@
+import os
+from datetime import datetime
+from time import sleep
+from typing import Optional
+
+import openai
+
+
+def create_tuning_job(
+    model: str,
+    training_file: str,
+    n_epochs: Optional[int] = None,
+    suffix: Optional[str] = None,
+):
+    out = openai.File.create(file=open(training_file, "rb"), purpose="fine-tune")
+    print(f"Created new file. FILE_ID = {out['id']}")
+    print("Waiting for the file to be processed ...")
+    while not wait_file_ready(out["id"]):
+        sleep(5)
+    # remove the local training file once it has been uploaded
+    os.remove(training_file)
+    params = {
+        "model": model,
+        "training_file": out["id"],
+    }
+    if n_epochs is not None:
+        params["hyperparameters"] = {"n_epochs": n_epochs}
+    if suffix is not None:
+        params["suffix"] = suffix
+    return openai.FineTuningJob.create(**params)
+
+
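+# Polls the tuning job until it reaches a terminal state: returns the job object
+# on success, raises a RuntimeError if the job failed or was cancelled.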
+def await_results(job_id: str, check_interval: int = 120):
+    while True:
+        job = openai.FineTuningJob.retrieve(job_id)
+        status = job["status"]
+        if status == "succeeded":
+            return job
+        elif status in ("failed", "cancelled"):
+            print(job)
+            raise RuntimeError(f"Tuning job failed with status {status}")
+        else:
+            now = datetime.now()
+            print(
+                f"[{now}] Waiting for the tuning job to complete."
+                f" Current status: {status}"
+            )
+            sleep(check_interval)
+
+
+def delete_file(file_id: str):
+    openai.File.delete(file_id)
+
+
+def wait_file_ready(file_id: str):
+    files = openai.File.list()["data"]
+    found = False
+    for file in files:
+        if file["id"] == file_id:
+            found = True
+            if file["status"] == "processed":
+                return True
+            elif file["status"] in ["error", "deleting", "deleted"]:
+                print(file)
+                raise RuntimeError(
+                    f"File upload {file_id} failed with status {file['status']}"
+                )
+            else:
+                return False
+    if not found:
+        raise RuntimeError(f"File {file_id} not found")