In [73]:
import os
import asyncio
import json
import weave
import pandas as pd
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn
from mistralai.models.chat_completion import ChatMessage

from dotenv import load_dotenv
load_dotenv()


True

In [15]:
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

weave.init("mistral_hackathon")

@weave.op()
async def call_mistral(model:str, messages:list, **kwargs) -> str:
    "Call the Mistral API"
    chat_response = await client.chat(
        model=model,
        messages=messages,
        **kwargs,
    )
    return chat_response.choices[0].message.content

In [30]:
def create_messages(keyword: str, cls=ChatMessage):
    messages = [
        cls(
            role="user", 
            content=(
                "You are a world-class comedy writer specializing in Chilean humor."
                "You will write a joke in Chilean Spanish based on the keyword provided by the user."
                "Only output the joke, ignore any other explanation or context."
                "Write in Chilean Spanish."
                 )
        ),
        cls(
            role="assistant", 
            content=(
                "Sure, I'd be happy to help writing a new joke in Chilean Spanish.")
        ),
        cls(
            role="user", 
            content=f"Write a joke in Chilean Spanish based on the following keyword: {keyword}."
        )
    ]
    return messages

In [31]:
@weave.op()
async def humor_writer(keyword:str, model:str) -> str:
    "Write a new joke"
     
    messages = create_messages(keyword=keyword)

    joke = await call_mistral(model=model, messages=messages)
    return {"keyword": keyword, "joke": joke}



In [32]:
ds_eval = weave.ref('ds_eval:latest').get()

In [33]:
res = await humor_writer(keyword=ds_eval.rows[0]['keyword'], model="mistral-medium-latest")
print(ds_eval.rows[0]['keyword'])
print(res["joke"])

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/fa320baf-7131-430e-80b2-0c79f6548046
pastor
Ahí va, aquí tienes un chiste en español chileno basado en la palabra "pastor":

¿Sabes por qué el pastor lleva siempre un reloj despertador?

Porque no quiere perderse la oportunidad de despertar a sus ovejas a tiempo para la misa matinal!

(Translation: Do you know why the shepherd always carries an alarm clock? Because he doesn't want to miss the chance to wake up his sheep on time for morning mass!)


In [43]:
class MistralModel(weave.Model):
    model: str
    temperature: float = 0.7
    
    @weave.op
    def create_messages(self, keyword:str):
        return create_messages(keyword)

    @weave.op
    async def predict(self, keyword:str):
        messages = self.create_messages(keyword)
        return await call_mistral(model=self.model, messages=messages)

In [34]:
mistral_medium = MistralModel(model="mistral-medium-latest")

In [39]:
async def async_foreach(sequence, func, max_concurrent_tasks):
    "Handy parallelism async for looper"
    semaphore = asyncio.Semaphore(max_concurrent_tasks)
    async def process_item(item):
        async with semaphore:
            result = await func(item)
            return item, result

    tasks = [asyncio.create_task(process_item(item)) for item in sequence]

    for task in asyncio.as_completed(tasks):
        item, result = await task
        yield item, result
        
async def map(ds, func, max_concurrent_tasks = 7, col_name="model_preds"):
    new_dataset = []
    async for example, map_results in async_foreach(ds.rows, func, max_concurrent_tasks):
        example.update({col_name: map_results})
        new_dataset.append(example)
    return new_dataset

ds_eval_medium_rows = await map(ds_eval, mistral_medium.predict, col_name="mistral_medium")

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/02749644-4092-4fd3-b96c-e7fe33f8ce4d
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/eeeb1f35-6fbc-43ac-ae41-19ecaa9b83d6
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/eefc6b22-2c7e-41dc-b896-85d8bb7a543e
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/f067f13f-c7ce-48d7-9b20-e318680169ab
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/d71dc188-351f-478b-a292-2764bf39c2ca
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/7fb18d53-4a61-472b-b9cf-d2fdee73077b
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/83b6e4b2-60d5-4ad1-a8ae-9ba57088266f
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/a904ac21-2826-4a49-aeda-39d46565f441
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/01eee0d8-88cc-475f-8a22-92816758cf33
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/0baa9147-eaba-4b86-84b1-a2122b81dec3
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/f19718b2-09eb-4fa8-8e22-ae1c9e2080d9
🍩 https://

In [40]:
ds_eval_medium = weave.Dataset(name="ds_eval_medium", description="Mistral medium predictions", rows=ds_eval_medium_rows)
weave.publish(ds_eval_medium)

📦 Published to https://wandb.ai/aastroza/mistral_hackathon/weave/objects/ds_eval_medium/versions/AwEyU1MpEYMVDSvVxjH2eWM8SXfbLj5gNO1Wwy9r2Bw


ObjectRef(entity='aastroza', project='mistral_hackathon', name='ds_eval_medium', digest='AwEyU1MpEYMVDSvVxjH2eWM8SXfbLj5gNO1Wwy9r2Bw', extra=[])

In [41]:
ds_eval_medium = weave.ref('ds_eval_medium:latest').get()

In [44]:
mistral_7b = MistralModel(model="open-mistral-7b")
ds_eval_7b_rows = await map(ds_eval_medium, mistral_7b.predict, col_name="mistral_7b")
ds_eval_7b_medium = weave.Dataset(name="ds_eval_medium_7b", description="Mistral 7b predictions along with medium", rows=ds_eval_7b_rows)
weave.publish(ds_eval_7b_medium)

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/066b3e58-71d9-4d89-a2d9-a003c08c9688
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/c094458e-d9f4-4124-a542-f3e705c1a0c2
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/5b8a8baa-c4fd-4f5e-b43c-2443155fd887
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/9566ad69-1741-440d-8da6-a828ea9e727c
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/1e2f1f64-3178-428a-83ca-f9cee2561752
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/7bc32437-64de-4c31-b092-1646aa6796eb
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/29cf5aec-4edf-48ff-8067-8c60cc371135
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/e6fa57f3-bf07-4b49-9524-e838140d2fdd
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/e4511872-e4d4-4d33-bf48-f16a863b367e
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/a094f806-8e9e-4791-b805-68d3575a1178
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/5a78f132-d5c5-4818-b368-59cbf78f8f73
🍩 https://

ObjectRef(entity='aastroza', project='mistral_hackathon', name='ds_eval_medium_7b', digest='iX7vdby09rE8oHt9hkBak0Ab0kzcgRzYWstYCBu9BYo', extra=[])

In [52]:
class LLMJudge(weave.Model):
    model: str = "mistral-large-latest"
    
    @weave.op
    async def predict(self, keyword: str, mistral_7b: str, mistral_medium: str, text: str, **kwargs) -> dict:
        messages = [
            ChatMessage(
                role="user",
                content=(
                "You are a world class comedian and you are judging a joke competition in Chile."
                "You have to pick the best joke between two jokes written about a keyword."
                "Take into consideration the jokes were written in Chilean Spanish and a ground truth joke as a reference. \n"
                "Here is the keyword: {keyword}\n"
                "Here is the joke1: {mistral_7b}\n"
                "Here is the joke2: {mistral_medium}\n"
                "Ground truth joke: {joke}\n"
                "Return the name of the best_joke (or None if you think both are bad) and the reason in short JSON object.").format(
                    keyword=keyword, 
                    mistral_7b=mistral_7b, 
                    mistral_medium=mistral_medium,
                    joke=text)
            )
        ]
        payload = await call_mistral(model=self.model, messages=messages, response_format={"type": "json_object"})
        return json.loads(payload)

In [53]:
ds_eval_7b_medium.rows[0].keys()

dict_keys(['text', 'keyword', 'mistral_medium', 'mistral_7b'])

In [54]:
llm_judge = LLMJudge()
res = await llm_judge.predict(**ds_eval_7b_medium.rows[0])
res

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/3b503a9b-a59c-4a36-b95b-8f36c7776584


{'best_joke': 'joke2',
 'reason': "Joke2 uses a play on words with 'tesoro' and 'escavar' to create a humorous and suggestive comment about finding an attractive person at the beach. It is more clever and funny than Joke1, which simply states that searching is the best way to find a treasure."}

In [56]:
@weave.op
def evaluate_joke(model_output: str) -> dict:
    "Evaluate the answer"
    return {"win": model_output["best_joke"] == "joke1"}

In [57]:
evaluation = weave.Evaluation(dataset=ds_eval_7b_medium, scorers=[evaluate_joke])

In [58]:
await evaluation.evaluate(llm_judge)

Traceback (most recent call last):
  File "c:\Users\Alonso\Dropbox\personal\repos\mistral-fine-tuning\.venv\Lib\site-packages\weave\flow\eval.py", line 283, in eval_example
    eval_row = await self.predict_and_score(model, example)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Alonso\Dropbox\personal\repos\mistral-fine-tuning\.venv\Lib\site-packages\weave\trace\op.py", line 141, in _run_async
    output = await awaited_res
             ^^^^^^^^^^^^^^^^^
  File "c:\Users\Alonso\Dropbox\personal\repos\mistral-fine-tuning\.venv\Lib\site-packages\weave\flow\eval.py", line 221, in predict_and_score
    result = await async_call(score_fn, **score_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Alonso\.rye\py\cpython@3.12.1\install\Lib\asyncio\threads.py", line 25, in to_thread
    return await loop.run_in_executor(None, func_call)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Alonso\.rye\py\cpython@3.12.1\

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/f45fe651-0171-4a40-a569-32e294aa847d


{'evaluate_joke': {'win': {'true_count': 47,
   'true_fraction': 0.43119266055045874}},
 'model_latency': {'mean': 26.51700864363154}}

In [64]:
def format_messages(row):
    "Format on the expected MistralAI fine-tuning dataset"
    keyword = row['keyword']
    joke = row['text']
    messages = create_messages(keyword, cls=dict)
    # we need to append the answer for training 👇
    messages = {"messages":messages + [dict(role="assistant", content=joke)]}
    return messages

In [65]:
df = pd.read_json('../data/processed/jokes.jsonl', lines=True)
df_train=df.sample(frac=0.95, random_state=200)
df_eval=df.drop(df_train.index)
len(df_train), len(df_eval)

(2093, 110)

In [67]:
formatted_df_train = df_train.apply(format_messages, axis=1)
formatted_df_eval = df_eval.apply(format_messages, axis=1)
formatted_df_train.head()

767     {'messages': [{'role': 'user', 'content': 'You...
538     {'messages': [{'role': 'user', 'content': 'You...
1637    {'messages': [{'role': 'user', 'content': 'You...
1666    {'messages': [{'role': 'user', 'content': 'You...
1759    {'messages': [{'role': 'user', 'content': 'You...
dtype: object

In [68]:
formatted_df_train.to_json("../data/processed/formatted_df_train.jsonl", orient="records", lines=True)
formatted_df_eval.to_json("../data/processed/formatted_df_eval.jsonl", orient="records", lines=True)

In [70]:
client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

with open("../data/processed/formatted_df_train.jsonl", "rb") as f:
    ds_train = client.files.create(file=("formatted_df_train.jsonl", f))
with open("../data/processed/formatted_df_eval.jsonl", "rb") as f:
    ds_eval = client.files.create(file=("eval.jsonl", f))

In [71]:
def pprint(obj):
    print(json.dumps(obj.dict(), indent=4))

In [72]:
pprint(ds_train)

{
    "id": "d8718303-7865-47a0-8ab8-9cc209d1bd0a",
    "object": "file",
    "bytes": 1793233,
    "created_at": 1719803995,
    "filename": "formatted_df_train.jsonl",
    "purpose": "fine-tune"
}


In [74]:
created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ds_train.id],
    validation_files=[ds_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=25,
        learning_rate=0.0001,
        ),
    integrations=[
        WandbIntegrationIn(
            project="mistral_hackathon",
            run_name="finetune_wandb",
            api_key=os.environ.get("WANDB_API_KEY"),
        ).dict()
    ],
)

In [75]:
pprint(created_jobs)

{
    "id": "ee377e19-d646-4d7f-b3f8-429dd318f606",
    "hyperparameters": {
        "training_steps": 25,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1719804119,
    "modified_at": 1719804119,
    "training_files": [
        "d8718303-7865-47a0-8ab8-9cc209d1bd0a"
    ],
    "validation_files": [
        "77977d54-1be9-4dc3-8143-7a5e81779c6c"
    ],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "mistral_hackathon",
            "name": null,
            "run_name": "finetune_wandb"
        }
    ]
}


In [76]:
import time

retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.jobs.retrieve(created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)

{
    "id": "ee377e19-d646-4d7f-b3f8-429dd318f606",
    "hyperparameters": {
        "training_steps": 25,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1719804119,
    "modified_at": 1719804120,
    "training_files": [
        "d8718303-7865-47a0-8ab8-9cc209d1bd0a"
    ],
    "validation_files": [
        "77977d54-1be9-4dc3-8143-7a5e81779c6c"
    ],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "mistral_hackathon",
            "name": null,
            "run_name": "finetune_wandb"
        }
    ],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1719804120
        },
        {
            "name": "status-updated",
            "data": {
                "status": "QUEUED"
            },
           

In [78]:
retrieved_jobs = client.jobs.retrieve(created_jobs.id)
pprint(retrieved_jobs)

{
    "id": "ee377e19-d646-4d7f-b3f8-429dd318f606",
    "hyperparameters": {
        "training_steps": 25,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": "ft:open-mistral-7b:fd6d41e7:20240701:ee377e19",
    "model": "open-mistral-7b",
    "status": "SUCCESS",
    "job_type": "FT",
    "created_at": 1719804119,
    "modified_at": 1719804310,
    "training_files": [
        "d8718303-7865-47a0-8ab8-9cc209d1bd0a"
    ],
    "validation_files": [
        "77977d54-1be9-4dc3-8143-7a5e81779c6c"
    ],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "mistral_hackathon",
            "name": null,
            "run_name": "finetune_wandb"
        }
    ],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "SUCCESS"
            },
            "created_at": 1719804310
        },
        {
            "name": "status-updated",
            "data": {
                "s

In [77]:
ds_eval_medium = weave.ref('ds_eval_medium:latest').get()

In [81]:
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

In [82]:
mistral_7b_ft = MistralModel(model=retrieved_jobs.fine_tuned_model)
ds_eval_7b_rows = await map(ds_eval_medium, mistral_7b_ft.predict, col_name="mistral_7b")
ds_eval_7b_ft_medium = weave.Dataset(name="ds_eval_medium_7b_ft", description="Finetuned Mistral 7b predictions along with medium", rows=ds_eval_7b_rows)
weave.publish(ds_eval_7b_ft_medium)

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/84da3856-0602-4f8a-9c79-7a215ae6e750
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/2e6dfc57-0ba4-4cf3-9288-b33d65559c2a
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/407fda3c-2125-47c2-be48-1eeeb669d315
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/477c61b2-e382-42f4-94f7-1674b0a99d8c
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/abb9c6fc-883b-4f66-8f50-08a9856f165f
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/c2f99fdd-61ca-40d5-ab77-9cfb1cf9ed48
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/1d61acf1-b3de-4f08-ba6c-e3855c49da27
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/968c362b-725f-4355-8353-10c74e9e9463
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/f0504664-875c-4ae4-a41e-13fe3e711824
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/1fc16902-3d3e-4333-9edc-b0795e7c1c97
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/60a7ea78-f269-43b3-914f-e48c8d98c8fd
🍩 https://

MistralException: Unexpected exception (ReadError): 

🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/8b5bcd48-e9b8-439b-9c2a-0173ebcf4e38


🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/dcb3a803-b510-4ed1-80f7-b64b076de4bc
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/a4d664c1-f59c-49a0-837f-b66de3d5c9ec
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/2888f849-4bb2-4e93-9544-b83d7b73153a
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/51ab5766-f7c2-4de4-9a7f-501fa8a487a4
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/e6310cd9-bb66-4173-9c54-7925a90256eb
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/7504b70a-9f3c-4e82-a24f-d33d394c0774
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/84cda7e9-a200-4677-821b-5fefd07bd2ef
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/bbbb6bc8-26db-4beb-95eb-939d08452fd7
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/46b457dd-d73f-43e7-9a9f-4775a91f8e9e
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/de86a67c-a702-4cdf-bc94-73dffe54f07c
🍩 https://wandb.ai/aastroza/mistral_hackathon/r/call/fae49a4e-fca6-4127-b6cf-f459a1811ce3
🍩 https://

In [None]:
evaluation = weave.Evaluation(dataset=ds_eval_7b_ft_medium, scorers=[evaluate_joke])

In [None]:
await evaluation.evaluate(llm_judge)