# Prompt tuning example for the 'google/flan-t5-xl' model


## Library imports


In [115]:
import os
from os import path
import time
from pathlib import Path


## Dynamically import the 'datasets' library


In [116]:
try:
    from datasets import load_dataset
except ImportError:
    print("Install datasets: it is a pre-requisite to run this example")
    

## Dynamically import the 'pandas' library


In [117]:
try:
    import pandas as pd
except ImportError:
    print("Install pandas: it is a pre-requisite to run this example")
    

## Import of GenAI libraries


In [92]:
from genai.credentials import Credentials # To load API key
from genai.model import Model # To create a model
from genai.schemas.generate_params import GenerateParams # To create the model params
from genai.schemas.tunes_params import CreateTuneHyperParams, TunesListParams # To create params for tuning
from genai.services import FileManager # To upload data files for tune to workbench
from genai.services.tune_manager import TuneManager # This is the class to invoke tune methods


## Load API credentials


In [118]:
# python-dotenv loads environment variables from a .env file into the process environment (os.environ).
from dotenv import load_dotenv
load_dotenv()


True

## Declare filepaths and other variables


In [119]:
# Everything lives under ./data relative to the directory the notebook runs in.
current_path = os.getcwd()
data_root = os.path.join(current_path, "data")
training_file, validation_file = (
    os.path.join(data_root, jsonl_name)
    for jsonl_name in ("fpb_train.jsonl", "fpb_validation.jsonl")
)

# Number of records written to the training / validation JSONL files.
num_training_samples, num_validation_samples = 100, 20


# Declare a function to create a dataset


In [120]:
def create_dataset(seed=None):
    """Create the training and validation JSONL files if they do not exist yet.

    Loads the "sentences_allagree" configuration of the "financial_phrasebank"
    dataset, draws ``num_training_samples + num_validation_samples`` random
    rows from its "train" split, renames the columns to "input"/"output",
    stringifies the "output" column and writes the first
    ``num_training_samples`` rows to ``training_file`` and the remaining rows
    to ``validation_file`` (one JSON record per line).

    Nothing is downloaded or written when both files are already present.

    Args:
        seed: Optional ``random_state`` forwarded to ``DataFrame.sample`` so
            the drawn rows are reproducible. The default ``None`` keeps the
            sampling unseeded.
    """
    Path(data_root).mkdir(parents=True, exist_ok=True)
    if path.exists(training_file) and path.exists(validation_file):
        return

    data = load_dataset(
        "financial_phrasebank",
        "sentences_allagree",
    )
    total_samples = num_training_samples + num_validation_samples
    # Chained, non-mutating transforms instead of inplace edits on the frame.
    samples = (
        pd.DataFrame(data["train"])
        .sample(n=total_samples, random_state=seed)
        .rename(columns={"sentence": "input", "label": "output"})
        .astype({"output": str})
    )

    splits = (
        (training_file, samples.iloc[:num_training_samples]),
        (validation_file, samples.iloc[num_training_samples:]),
    )
    for target_file, frame in splits:
        # Explicit encoding so the result does not depend on the platform default.
        with open(target_file, "w", encoding="utf-8") as fout:
            fout.write(frame.to_json(orient="records", lines=True, force_ascii=True))
        

In [48]:
# Test
# create_dataset()

# Declare a function to upload the tuning files


In [122]:
def upload_files(creds, update=True):
    """Upload the training and validation files for tuning.

    For each of ``training_file`` and ``validation_file``:

    * if no remote file with the same name is listed, the file is uploaded;
    * if one or more are listed and ``update`` is true, all of them are
      deleted and the local file is uploaded again;
    * if one or more are listed and ``update`` is false, nothing is done.

    Args:
        creds: Credentials passed through to every ``FileManager`` call.
        update: Whether to overwrite files that are already present.
    """
    fileinfos = FileManager.list_files(credentials=creds).results

    # Keep every id per name: a plain name -> id dict would retain only one of
    # several same-named files, leaving stale copies behind that
    # get_file_ids() would still return.
    ids_by_name = {}
    for info in fileinfos:
        ids_by_name.setdefault(info.file_name, []).append(info.id)

    for filepath in [training_file, validation_file]:
        filename = Path(filepath).name
        existing_ids = ids_by_name.get(filename, [])
        if not existing_ids:
            print(f"File not present: Uploading {filename}")
        elif update:
            print(f"File already present: Overwriting {filename}")
            for file_id in existing_ids:
                FileManager.delete_file(credentials=creds, file_id=file_id)
        else:
            # Already present and overwriting was not requested.
            continue
        FileManager.upload_file(credentials=creds, file_path=str(filepath), purpose="tune")
            

In [52]:
# creds = get_creds()
# create_dataset()
# upload_files(creds, update=True)

File already present: Overwriting fpb_train.jsonl
File already present: Overwriting fpb_validation.jsonl


# Declare a function to get the ids of all training and validation files


In [123]:
def get_file_ids(creds):
    """Return the ids of the listed files named like the local tuning files.

    Args:
        creds: Credentials passed to ``FileManager.list_files``.

    Returns:
        A ``(training_file_ids, validation_file_ids)`` tuple of lists holding
        the id of every listed file whose name equals the base name of
        ``training_file`` / ``validation_file`` respectively.
    """
    training_name = Path(training_file).name
    validation_name = Path(validation_file).name

    training_file_ids, validation_file_ids = [], []
    for info in FileManager.list_files(credentials=creds).results:
        if info.file_name == training_name:
            training_file_ids.append(info.id)
        if info.file_name == validation_name:
            validation_file_ids.append(info.id)
    return training_file_ids, validation_file_ids


In [57]:
# Test
# creds = get_creds()
# get_file_ids(creds)

(['9d14c064-f76a-4522-a17c-eb1fb8a85fd8'],
 ['47500724-e9f6-4190-a339-ecee34c62b50'])

# Declare a function to get your credentials


In [99]:
def get_creds():
    """Build a ``Credentials`` object from environment variables.

    The API key is read from ``GENAI_KEY`` and the endpoint from
    ``GENAI_API``; a variable that is not set is passed on as ``None``.

    Returns:
        The ``Credentials`` instance built from those two values.
    """
    env = os.environ
    return Credentials(
        api_key=env.get("GENAI_KEY"),
        api_endpoint=env.get("GENAI_API"),
    )


# MAIN PROGRAM


## Create tuning dataset and upload it to server


In [124]:
# Order matters: credentials first, then the local JSONL files, then the upload.
creds = get_creds()
# No-op when both JSONL files already exist locally.
create_dataset()
# update=True replaces same-named files that are already listed remotely.
upload_files(creds, update=True)


File already present: Overwriting fpb_train.jsonl
File already present: Overwriting fpb_validation.jsonl


## Create an instance of a model, create tune parameters, and get the ids of training and validation files uploaded

In [126]:
# Handle for the base model; no generation parameters are attached (params=None).
model = Model("google/flan-t5-xl", params=None, credentials=creds)
# Task: classification
# The verbalizer lists the labels "0", "1", "2"; `{{input}}` is presumably replaced
# with each record's "input" field by the service — TODO confirm against the SDK docs.
hyperparams = CreateTuneHyperParams(num_epochs=2, verbalizer='classify { "0", "1", "2" } Input: {{input}} Output:')
# Ids of the remote files whose names match the local training / validation files.
training_file_ids, validation_file_ids = get_file_ids(creds)


## Tune the model

In [127]:
# Start the tuning process and keep the returned object so its status can be checked later.
tuned_model = model.tune(
    name="classification-mpt-tune-api",
    method="mpt",
    task="classification",
    hyperparameters=hyperparams,
    training_file_ids=training_file_ids,
    validation_file_ids=validation_file_ids,
)

In [128]:
# Poll every 20 seconds, printing each non-terminal status, until the tune
# reports one of the terminal states below.
terminal_states = ("FAILED", "HALTED", "COMPLETED")
while True:
    status = tuned_model.status()
    if status in terminal_states:
        break
    print(status)
    time.sleep(20)
    

RUNNING
RUNNING
RUNNING
RUNNING
RUNNING
RUNNING
RUNNING
RUNNING


In [130]:
print(tuned_model.status()) # Run to check final status


COMPLETED



## Test the tuned model


In [131]:
prompt = "Hi, how are you? I'm doing well"
genparams = GenerateParams(
    decoding_method="greedy",
    max_new_tokens=50,
    min_new_tokens=1,
)

# `genparams` was previously built but never passed to anything, so the greedy
# decoding settings had no effect. Bind them to a handle for the tuned model id
# (the same id that is used as `tune_id` further down) and generate with that.
greedy_tuned_model = Model(tuned_model.model, params=genparams, credentials=creds)

print("Answer = ", greedy_tuned_model.generate([prompt])[0].generated_text)


Answer =  1


In [132]:
# Send the same prompt ten times and print each result with a 1-based counter.
greeting = "Hello! How are you?"
lots_of_greetings = [greeting for _ in range(10)]
num_of_greeting = 0  # stays 0 if no results come back
for num_of_greeting, result in enumerate(tuned_model.generate(lots_of_greetings), start=1):
    print(f"\t {num_of_greeting}-{result.generated_text}")
    

	 1-2
	 2-1
	 3-0
	 4-0
	 5-1
	 6-1
	 7-1
	 8-1
	 9-2
	 10-2


## Listing tunes and getting tune metadata with TuneManager


In [133]:
# Listing parameters: limit=5, offset=0.
list_params = TunesListParams(limit=5, offset=0)

tune_list = TuneManager.list_tunes(credentials=creds, params=list_params)
print("\n\nList of tunes: \n\n")
# Print each returned tune record followed by a blank line.
for tune in tune_list.results:
    print(tune, "\n")
    



List of tunes: 


id='flan-t5-xl-mpt-kGGKGEyS-2023-07-13-17-44-07' name='classification-mpt-tune-api' model_id='google/flan-t5-xl' model_name='flan-t5-xl (3B)' method_id='mpt' method_name='Multitask Prompt Tuning' status='COMPLETED' task_id='classification' task_name='Classification' parameters=TuneParameters(accumulate_steps=16, batch_size=16, learning_rate=0.3, max_input_tokens=256, max_output_tokens=128, num_epochs=2, num_virtual_tokens=100, verbalizer='classify { "0", "1", "2" } Input: {{input}} Output:') created_at=datetime.datetime(2023, 7, 13, 17, 44, 7, tzinfo=datetime.timezone.utc) preferred=True datapoints=None validation_files=None training_files=None evaluation_files=None status_message=None started_at=datetime.datetime(2023, 7, 13, 17, 44, 8, tzinfo=datetime.timezone.utc) finished_at='2023-07-13T17:46:40.000Z' 

id='flan-t5-xl-mpt-le7Yfp0u-2023-07-13-02-36-57' name='classification-mpt-tune-api' model_id='google/flan-t5-xl' model_name='flan-t5-xl (3B)' method_id='mpt' met

In [134]:
# Fetch the metadata of the tune created above; `tuned_model.model` is passed as the tune id.
tune_get_result = TuneManager.get_tune(credentials=creds, tune_id=tuned_model.model)
print(
    "\n\n~~~~~ Metadata for a single tune with TuneManager ~~~~: \n\n",
    tune_get_result,
)




~~~~~ Metadata for a single tune with TuneManager ~~~~: 

 id='flan-t5-xl-mpt-kGGKGEyS-2023-07-13-17-44-07' name='classification-mpt-tune-api' model_id='google/flan-t5-xl' model_name='flan-t5-xl (3B)' method_id='mpt' method_name='Multitask Prompt Tuning' status='COMPLETED' task_id='classification' task_name='Classification' parameters=TuneParameters(accumulate_steps=16, batch_size=16, learning_rate=0.3, max_input_tokens=256, max_output_tokens=128, num_epochs=2, num_virtual_tokens=100, verbalizer='classify { "0", "1", "2" } Input: {{input}} Output:') created_at=datetime.datetime(2023, 7, 13, 17, 44, 7, tzinfo=datetime.timezone.utc) preferred=True datapoints={'loss': [{'data': {'epoch': 0, 'value': 0.8125, 'timestamp': '2023-07-13T17:46:26.409898'}, 'timestamp': '2023-07-13T17:46:26.409Z'}, {'data': {'epoch': 1, 'value': 0.67578125, 'timestamp': '2023-07-13T17:46:31.066027'}, 'timestamp': '2023-07-13T17:46:31.066Z'}]} validation_files=[{'id': '8c181aad-2058-4d6f-8e7d-f4ecbdbbe1a4', 'fi

## Deleting a tuned model


In [111]:
# Manual switch: set to anything other than "y" to keep the tuned model.
to_delete = "y"
if to_delete == "y":
    # After this call the tuned model is deleted, so cells that use `tuned_model`
    # presumably stop working — NOTE(review): confirm against the SDK docs.
    tuned_model.delete()
    