# Finetuning on Tune Studio using `tuneapi`

In this example we build a list of threads (chats) to use as training samples, upload it to Tune Studio, and use it to finetune a `meta/llama-3-8b-instruct-8k` model.

Every step uses the `tuneapi` package, so you can reuse the same code in your own projects.

In [1]:
import os
from tuneapi import types as tt
from tuneapi import endpoints as te

# 🤗 datasets
from datasets import load_dataset



In [2]:
# a sample pseudo-ranker dataset I built from `nvidia/HelpSteer2` in the format the Tune platform expects
ds = load_dataset("yashnbx/hs2-tune")["train"]
print(ds[0])

{'id': 'JbWB9D',
 'conversations': [{'from': 'system',
   'value': 'You are "Ranker Assistant" whose job is to rated each response on a Likert 5 scale (between 0 and 4) for each of these attributes:\n\n1.  Helpfulness: Overall helpfulness of the response to the prompt.\n2.  Correctness: Inclusion of all pertinent facts without errors.\n3.  Coherence: Consistency and clarity of expression.\n4.  Complexity: Intellectual depth required to write response (i.e. whether the response can be written by anyone with basic\n    language competency or requires deep domain expertise).\n5.  Verbosity: Amount of detail included in the response, relative to what is asked for in the prompt.\n\n\nSome prompt can be multi-turn. In this case, the prompt consists of all of the user turns and all but the last assistant\nturn, which is contained in the response field. This is done because the attribute values only assessed only for the last\nassistant turn. For multi-turn prompts, the structure of prompts lo
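Each record holds a `conversations` list of ShareGPT-style turns, where `from` names the speaker and `value` holds the text. In this notebook `tt.Thread.from_dict` consumes these records directly. The sketch below is not its implementation; it only illustrates how such turns map onto the more common `role`/`content` message shape. Only the `system` tag is visible in the sample above. The `human`/`gpt` tags are the usual ShareGPT convention and are an assumption about the rest of the dataset, and `ROLE_MAP`/`to_messages` are hypothetical names.

```python
# Hypothetical mapping from ShareGPT-style speaker tags to chat roles.
# 'system' appears in the sample record; 'human' and 'gpt' are the usual
# ShareGPT tags and are ASSUMED here, not confirmed for this dataset.
ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def to_messages(record: dict) -> list:
    """Convert a record's 'conversations' list of {'from', 'value'} turns
    into a list of {'role', 'content'} messages."""
    messages = []
    for turn in record["conversations"]:
        # Unknown tags are passed through unchanged rather than dropped.
        role = ROLE_MAP.get(turn["from"], turn["from"])
        messages.append({"role": role, "content": turn["value"]})
    return messages


# A tiny hand-made record with the same shape as the dataset rows.
sample = {
    "id": "example",
    "conversations": [
        {"from": "system", "value": "You are a ranker."},
        {"from": "human", "value": "Rate this response."},
        {"from": "gpt", "value": "helpfulness: 3"},
    ],
}
print([m["role"] for m in to_messages(sample)])  # ['system', 'user', 'assistant']
```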

In [3]:
# build a threads list from `yashnbx/hs2-tune` dataset
if not os.path.exists("./hs2-cookbook"):
    ds = load_dataset("yashnbx/hs2-tune")["train"]
    threads = tt.ThreadsList()
    for i in range(200):
        threads.append(tt.Thread.from_dict(ds[i]))
    print(threads)
    threads.to_disk("./hs2-cookbook")
else:
    # already saved, just deserialize
    threads = tt.ThreadsList.from_disk("./hs2-cookbook")
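The cell above follows a build-once, load-afterwards pattern, so later runs skip re-downloading and re-converting the dataset. A generic version of the same idea is sketched below. `load_or_build` is a hypothetical helper, and `pickle` is a swapped-in serializer standing in for `ThreadsList.to_disk`/`from_disk`. It is not what `tuneapi` uses internally. Only unpickle files you created yourself, since `pickle.load` can execute arbitrary code.

```python
import os
import pickle
import tempfile
from typing import Callable, TypeVar

T = TypeVar("T")


def load_or_build(path: str, build: Callable[[], T]) -> T:
    """Return the object cached at `path`, building and caching it on a miss."""
    if os.path.exists(path):
        # Cache hit: skip the (possibly expensive) build step.
        with open(path, "rb") as f:
            return pickle.load(f)
    obj = build()
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return obj


path = os.path.join(tempfile.mkdtemp(), "threads.pkl")
first = load_or_build(path, lambda: [{"id": i} for i in range(3)])
second = load_or_build(path, lambda: [])  # builder not called: cache hit
print(first == second)  # True
```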

In [4]:
ft = te.FinetuningAPI(
    tune_api_key="xxxxx",  # your Tune Studio API key
    tune_org_id="yyyyy",  # only needed if you belong to multiple organizations
)
out = ft.upload_dataset(
    threads=threads,
    name="hs2-test",
    override=True
)
print(out)

[2024-08-11T16:27:54+0530] [INFO] [finetune.py:66] Dataset saved to tuneds/hs2-test/tuneds.jsonl
[2024-08-11T16:27:55+0530] [INFO] [finetune.py:92] Upload successful!


FTDataset(path='datasets/chat/hs2-test.jsonl', type='relic')
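As the log shows, `upload_dataset` first writes the threads to a local JSON Lines file (`tuneds/hs2-test/tuneds.jsonl`) and then uploads it. The training config later lists the dataset type as `sharegpt`. A quick local sanity check of such a file is sketched below. It assumes, based on the ShareGPT convention and the source dataset's shape, that each line is a JSON object with a non-empty `conversations` list. That schema is not confirmed for `tuneapi`'s export, and `check_jsonl` is a hypothetical helper.

```python
import json
import os
import tempfile


def check_jsonl(path: str) -> int:
    """Validate that every non-blank line is a JSON object with a non-empty
    'conversations' list (ASSUMED schema); return the number of records."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            record = json.loads(line)  # raises ValueError on malformed JSON
            convs = record.get("conversations") if isinstance(record, dict) else None
            if not isinstance(convs, list) or not convs:
                raise ValueError(f"line {lineno}: missing or empty 'conversations'")
            count += 1
    return count


# Write a small two-record file in the assumed format and check it.
path = os.path.join(tempfile.mkdtemp(), "tuneds.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for i in range(2):
        rec = {"conversations": [{"from": "human", "value": f"q{i}"},
                                 {"from": "gpt", "value": f"a{i}"}]}
        f.write(json.dumps(rec) + "\n")
print(check_jsonl(path))  # 2
```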


In [5]:
# you can finetune on several datasets at once by passing them as a list;
# this creates a finetuning job and returns a model object, but training itself takes some time to complete
model = ft.finetune("ranker2-demo", datasets=[out])
model

[2024-08-11T15:51:29+0530] [INFO] [finetune.py:131] Finetuning job created with ID: azxrbqsj. Check progress at: https://studio.tune.app/finetuning/azxrbqsj?org_id=e3365ae7-ceeb-425b-b983-1703e8456f76


<TuneModel: ranker2-demo-model-azxrbqsj | e3365ae7-ceeb-425b-b983-1703e8456f76>
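The log above shows that `ft.finetune` returns as soon as the job is created, while training keeps running on Tune Studio. To block until the job finishes, you can wrap a status call in a polling loop. The sketch below is a generic helper and not part of `tuneapi`. `wait_for` is a hypothetical name, and its default interval and timeout are arbitrary choices. With a real job, `fetch` could be `lambda: ft.get_job("azxrbqsj")`. The `is_done` predicate must then inspect whichever status field the job dictionary exposes, which is not visible in the truncated output shown here.

```python
import time
from typing import Any, Callable


def wait_for(fetch: Callable[[], Any], is_done: Callable[[Any], bool],
             interval: float = 60.0, timeout: float = 6 * 3600) -> Any:
    """Call `fetch` every `interval` seconds until `is_done(result)` is true.
    Raises TimeoutError if `timeout` seconds elapse first."""
    deadline = time.monotonic() + timeout
    while True:
        result = fetch()
        if is_done(result):
            return result
        if time.monotonic() >= deadline:
            raise TimeoutError("job did not finish before the timeout")
        time.sleep(interval)


# Stand-in for a real status call: reports "running" twice, then "done".
states = iter(["running", "running", "done"])
result = wait_for(lambda: next(states), lambda s: s == "done", interval=0.01)
print(result)  # done
```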

In [6]:
# call this repeatedly to check the job's status
ft.get_job("azxrbqsj")

{'id': 'azxrbqsj',
 'name': 'ranker2-demo',
 'resource': {'gpu': 'nvidia-l4',
  'gpuCount': '1',
  'diskSize': '30Gi',
  'maxRetries': 1},
 'meta': {'metadata': {'base_model_id': '5fmycsn2',
   'modality': 1,
   'training_config': {'adapter': 'qlora',
    'base_model': 'meta-llama/Meta-Llama-3-8B-Instruct',
    'base_model_config': 'meta-llama/Meta-Llama-3-8B-Instruct',
    'chat_template': 'llama3',
    'datasets': [{'conversation': 'llama3',
      'data_files': '/root/.cache/model/chat/hs2-test.jsonl',
      'ds_type': 'json',
      'path': '/root/.cache/model/chat/hs2-test.jsonl',
      'type': 'sharegpt'}],
    'eval_sample_packing': False,
    'eval_steps': 50,
    'flash_attention': True,
    'gradient_accumulation_steps': 4,
    'gradient_checkpointing': True,
    'hf_use_auth_token': True,
    'learning_rate': 0.0001,
    'load_in_4bit': True,
    'logging_steps': 1,
    'lora_alpha': 16,
    'lora_dropout': 0.05,
    'lora_r': 32,
    'lora_target_linear': True,
    'lr_schedu