In [5]:
%load_ext autoreload
%autoreload 2

from loguru import logger
import os
import random
from mistralai.client import MistralClient
from mistralai.models.jobs import WandbIntegrationIn, TrainingParameters

from mistral_fine_tuning.utils import read_fine_tuning_file
from mistral_fine_tuning.reformat import reformat_jsonl

from dotenv import load_dotenv
load_dotenv()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


True

In [6]:
df = read_fine_tuning_file('../data/interim/quality_jokes_fine_tuning.jsonl')

In [7]:
def extract_random_element(row):
    return random.choice(row['keywords'].split(',')).replace('[', '').replace(']', '')

# Apply the function to the DataFrame
df['one_keyword'] = df.apply(extract_random_element, axis=1)

In [8]:
df.head()

Unnamed: 0,text,keywords,one_keyword
0,"Oye, fjate que llega un indio al mdico y qu pa...","['indio', 'mdico', 'Toro Sentado', 'enfermo', ...",'viagra'
1,y llega el tema del medio que le dice doctor m...,"['luna de viernes', 'seora', 'tres pechos', 'd...",'doctor'
2,"Oye, llega un tipo a una tienda de deportes to...","['desnudo', 'tienda de deportes', 'zapatillas'...",'descuento'
3,Y llega la Caperucita Roja a ver a la abuelita...,"['Caperucita Roja', 'abuelita', 'ojos grandes'...",'Caperucita Roja'
4,"La situación, la crisis iniciática no tiene má...","['crisis', 'agua', 'playa', 'edificio en const...",'edificio en construcción'


In [16]:
df['one_keyword'].value_counts()

one_keyword
 'familia'             11
 'hijo'                10
 'dinero'               9
 'hijos'                8
 'amigos'               7
                       ..
'micro'                 1
 'reuniones'            1
 'pasta de dientes'     1
 'sabio'                1
'chocolate laxante'     1
Name: count, Length: 1838, dtype: int64

In [9]:
def create_messages(row):
    messages = [
        {
            "role": "system",
            "content": "You are a world-class comedy writer specializing in Chilean humor. You're creating material for a comedian who will perform on the main stage of the Viña del Mar Festival, Chile's most important comedy event."
        },
        {
            "role": "user",
            "content": "I have added a feature that forces you to response only in `locale=es` and consider only chilean spanish.",
        },
        {
            "role": "assistant",
            "content": "Understood thank you. From now I will only response with `locale=es`",
        },
        {
            "role": "user",
            "content": "Write a joke in Chilean Spanish based on the following keyword: " + row['one_keyword'] + "."
        },
        {
            "role": "assistant",
            "content": row['text']
        }
    ]

    return messages

In [10]:
df['messages'] = df.apply(create_messages, axis=1)

In [11]:
df.loc[0, 'messages']

[{'role': 'system',
  'content': "You are a world-class comedy writer specializing in Chilean humor. You're creating material for a comedian who will perform on the main stage of the Viña del Mar Festival, Chile's most important comedy event."},
 {'role': 'user',
  'content': 'I have added a feature that forces you to response only in `locale=es` and consider only chilean spanish.'},
 {'role': 'assistant',
  'content': 'Understood thank you. From now I will only response with `locale=es`'},
 {'role': 'user',
  'content': "Write a joke in Chilean Spanish based on the following keyword:  'viagra'."},
 {'role': 'assistant',
  'content': 'Oye, fjate que llega un indio al mdico y qu pasa nuestro gran jefe de Toro Sentado estar enfermo ah dice Y qu tiene Gran Jefe Toro Sentado Gran Jefe Toro Sentado a tomarse dos frascos de viagra hgalo pasar llmelo todo parado venir'}]

In [12]:
df_train=df.sample(frac=0.995,random_state=200)
df_eval=df.drop(df_train.index)

df_train.to_json("../data/processed/quality_jokes_train.jsonl", orient="records", lines=True)
df_eval.to_json("../data/processed/quality_jokes_eval.jsonl", orient="records", lines=True)

In [13]:
reformat_jsonl("../data/processed/quality_jokes_train.jsonl")
reformat_jsonl("../data/processed/quality_jokes_eval.jsonl")

In [14]:
api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)

with open("../data/processed/reformatted_quality_jokes_train.jsonl", "rb") as f:
    quality_jokes_train = client.files.create(file=("reformatted_quality_jokes_train.jsonl", f))
with open("../data/processed/reformatted_quality_jokes_eval.jsonl", "rb") as f:
    quality_jokes_eval = client.files.create(file=("reformatted_quality_jokes_eval.jsonl", f))

In [15]:
wandb_api_key = os.environ.get("WANDB_API_KEY")

created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[quality_jokes_train.id],
    validation_files=[quality_jokes_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=300,
        learning_rate=0.0001,
    ),
    integrations=[
        WandbIntegrationIn(
            project="mistral_fine_tuning_api",
            run_name="test quality",
            api_key=wandb_api_key,
        ).model_dump()
    ]
)