In [61]:
%load_ext autoreload
%autoreload 2

from loguru import logger
import os
from mistralai.client import MistralClient
from mistralai.models.jobs import WandbIntegrationIn, TrainingParameters

from mistral_fine_tuning.utils import read_fine_tuning_file
from mistral_fine_tuning.reformat import reformat_jsonl

from dotenv import load_dotenv
# Load variables from a local .env file into the process environment.
# Presumably this supplies MISTRAL_API_KEY / WANDB_API_KEY, which later cells
# read through os.environ.get — confirm against the project's .env template.
load_dotenv()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


True

In [46]:
# Load the jokes fine-tuning dataset through the project helper; the result is
# used as a pandas DataFrame by the cells below (df.head, df.apply, df.sample).
# NOTE(review): the recorded run logged "Error processing keywords: invalid syntax"
# from mistral_fine_tuning.utils.process_keywords, so at least one row's keywords
# presumably failed to parse — check how that row is represented before training.
df = read_fine_tuning_file('../data/interim/jokes_fine_tuning.jsonl')

[32m2024-06-26 22:12:37.772[0m | [31m[1mERROR   [0m | [36mmistral_fine_tuning.utils[0m:[36mprocess_keywords[0m:[36m57[0m - [31m[1mError processing keywords: invalid syntax. Perhaps you forgot a comma? (<string>, line 1)[0m


In [47]:
# Preview up to the first 20 rows to sanity-check the `text` and `keywords` columns.
# NOTE(review): the recorded preview shows words with missing accented characters
# (e.g. 'mdico', 'fjate') — presumably an upstream encoding/cleaning issue; confirm.
df.head(20)

Unnamed: 0,text,keywords
0,"Oye, fjate que llega un indio al mdico y qu pa...","['indio', 'mdico', 'Toro Sentado', 'enfermo', ..."
1,y me dio qu le pas vena una delegacin de turno...,"['delegacin de turnos', 'Viña del Mar', 'bus d..."
2,Entonces el gua turstico si ustedes miran a la...,"['gua turstico', 'izquierda', 'derecha', 'quin..."
3,y llega el tema del medio que le dice doctor m...,"['luna de viernes', 'seora', 'tres pechos', 'd..."
4,conversando amigo y uno le dice la poblacin un...,"['Viagra', 'robo', 'medicamento', 'polica', 'h..."
5,"Oye, ¿sabes que hay un logo haciendo ese lindo...","['logo', 'Lola', 'discoteca', 'Hola', 'llamarse']"
6,"Deseas tomar algo?, bueno le dijo, pídeme un t...","['tomar algo', 'trago fuerte']"
7,"Oye, llega un tipo a una tienda de deportes to...","['desnudo', 'tienda de deportes', 'zapatillas'..."
8,Y llega la Caperucita Roja a ver a la abuelita...,"['Caperucita Roja', 'abuelita', 'ojos grandes'..."
9,"Gracias, gracias por invitarnos, la cuadragési...","['Festival de la cancin de Via del Mar', 'ltim..."


In [48]:
def create_messages(row):
    """Build the chat transcript used as one fine-tuning example.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas Series or a dict)
        Must provide ``row['keywords']`` (the keywords to mention in the
        prompt) and ``row['text']`` (the joke used as the target answer).

    Returns
    -------
    list[dict]
        Five ``{"role": ..., "content": ...}`` messages, in order: the system
        prompt, a user/assistant exchange pinning the locale, the user request
        containing the keywords, and the assistant reply holding ``row['text']``
        unchanged.
    """
    keywords = row['keywords']
    # The prompt is built by string concatenation, which raises TypeError for
    # non-string values (e.g. an already-parsed list of keywords). str() of a
    # list produces the same "['a', 'b']" rendering that string-typed keywords
    # already have, so the prompt format stays consistent across rows.
    if not isinstance(keywords, str):
        keywords = str(keywords)

    messages = [
        {
            "role": "system",
            "content": "You are a world-class comedy writer specializing in Chilean humor. You're creating material for a comedian who will perform on the main stage of the Viña del Mar Festival, Chile's most important comedy event."
        },
        {
            "role": "user",
            "content": "I have added a feature that forces you to response only in `locale=es` and consider only chilean spanish.",
        },
        {
            "role": "assistant",
            "content": "Understood thank you. From now I will only response with `locale=es`",
        },
        {
            "role": "user",
            "content": "Write a joke in Chilean Spanish based on the following keywords: " + keywords + "."
        },
        {
            "role": "assistant",
            "content": row['text']
        }
    ]

    return messages

In [49]:
# Build one chat transcript per row and store it in the "messages" column.
df["messages"] = df.apply(create_messages, axis="columns")

In [50]:
# Spot-check the transcript built for the row with index label 0.
df.loc[0, 'messages']

[{'role': 'system',
  'content': "You are a world-class comedy writer specializing in Chilean humor. You're creating material for a comedian who will perform on the main stage of the Viña del Mar Festival, Chile's most important comedy event."},
 {'role': 'user',
  'content': 'I have added a feature that forces you to response only in `locale=es` and consider only chilean spanish.'},
 {'role': 'assistant',
  'content': 'Understood thank you. From now I will only response with `locale=es`'},
 {'role': 'user',
  'content': "Write a joke in Chilean Spanish based on the following keywords: ['indio', 'mdico', 'Toro Sentado', 'enfermo', 'viagra']."},
 {'role': 'assistant',
  'content': 'Oye, fjate que llega un indio al mdico y qu pasa nuestro gran jefe de Toro Sentado estar enfermo ah dice Y qu tiene Gran Jefe Toro Sentado Gran Jefe Toro Sentado a tomarse dos frascos de viagra hgalo pasar llmelo todo parado venir'}]

In [54]:
# Randomly assign 99.5% of the rows to training (fixed seed for repeatability);
# whatever was not sampled becomes the evaluation split.
df_train = df.sample(frac=0.995, random_state=200)
df_eval = df.drop(index=df_train.index)

# Persist each split as JSON Lines: one record per line.
for split_frame, split_path in (
    (df_train, "../data/processed/jokes_train.jsonl"),
    (df_eval, "../data/processed/jokes_eval.jsonl"),
):
    split_frame.to_json(split_path, orient="records", lines=True)

In [58]:
# Run the project's reformatting helper on both splits, training first.
# Presumably it writes a `reformatted_`-prefixed copy next to each input (the
# upload cell opens files with that prefix) — confirm in mistral_fine_tuning.reformat.
for raw_split_path in (
    "../data/processed/jokes_train.jsonl",
    "../data/processed/jokes_eval.jsonl",
):
    reformat_jsonl(raw_split_path)

In [60]:
# Create the Mistral client from the API key found in the environment.
api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)

# Upload the reformatted splits (training first, then evaluation). Each entry
# pairs the local path with the file name sent along with the upload.
uploaded_files = []
for local_path, remote_name in (
    ("../data/processed/reformatted_jokes_train.jsonl", "reformatted_jokes_train.jsonl"),
    ("../data/processed/reformatted_jokes_eval.jsonl", "reformatted_jokes_eval.jsonl"),
):
    with open(local_path, "rb") as f:
        uploaded_files.append(client.files.create(file=(remote_name, f)))

# Keep the names the job-creation cell refers to.
ultrachat_chunk_train, ultrachat_chunk_eval = uploaded_files

In [62]:
# Weights & Biases key for the logging integration, read from the environment.
wandb_api_key = os.environ.get("WANDB_API_KEY")

# Hyperparameters for the fine-tuning run.
training_params = TrainingParameters(
    training_steps=300,
    learning_rate=0.0001,
)

# The integration is passed to the API as a plain dict, hence model_dump().
wandb_integration = WandbIntegrationIn(
    project="mistral_fine_tuning_api",
    run_name="test",
    api_key=wandb_api_key,
).model_dump()

# Launch the fine-tuning job on the previously uploaded train/eval files.
created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ultrachat_chunk_train.id],
    validation_files=[ultrachat_chunk_eval.id],
    hyperparameters=training_params,
    integrations=[wandb_integration],
)