In [None]:
import os
import asyncio
import json
import weave
import pandas as pd
import time
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn
from mistralai.models.chat_completion import ChatMessage

from dotenv import load_dotenv
load_dotenv()


In [None]:
# Async Mistral client awaited by `call_mistral`; the API key is read from the environment.
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

# Initialise Weave for the "mistral_hackathon" project.
weave.init("mistral_hackathon")

@weave.op()
async def call_mistral(model:str, messages:list, **kwargs) -> str:
    """Send `messages` to the Mistral chat endpoint and return the reply text.

    Any extra keyword arguments are forwarded unchanged to `client.chat`.
    Only the content of the first returned choice is used.
    """
    response = await client.chat(model=model, messages=messages, **kwargs)
    first_choice = response.choices[0]
    return first_choice.message.content

In [None]:
def create_messages(keyword: str, cls=ChatMessage):
    """Build the three-turn chat prompt asking for a joke in Chilean Spanish.

    Args:
        keyword: Topic the joke must be about.
        cls: Callable invoked as ``cls(role=..., content=...)`` for every turn.
            Defaults to ``ChatMessage``; pass ``dict`` to get plain dictionaries.

    Returns:
        A list of three messages: the writing instructions (sent with the
        "user" role), a canned assistant acknowledgement, and the user request
        that contains `keyword`.
    """
    # Join the sentences with spaces: as adjacent string literals they were
    # concatenated with nothing in between ("...humor.You will...").
    instructions = " ".join([
        "You are a world-class comedy writer specializing in Chilean humor.",
        "You will write a joke in Chilean Spanish based on the keyword provided by the user.",
        "Only output the joke, ignore any other explanation or context.",
        "Write in Chilean Spanish.",
    ])
    return [
        cls(role="user", content=instructions),
        cls(
            role="assistant",
            content="Sure, I'd be happy to help writing a new joke in Chilean Spanish.",
        ),
        cls(
            role="user",
            content=f"Write a joke in Chilean Spanish based on the following keyword: {keyword}.",
        ),
    ]

In [None]:
@weave.op()
async def humor_writer(keyword:str, model:str) -> dict:
    """Write a new joke about `keyword` with the given Mistral model.

    Returns:
        A dict with two keys: "keyword" (the input, echoed back) and "joke"
        (the text returned by `call_mistral`).
    """
    messages = create_messages(keyword=keyword)
    joke = await call_mistral(model=model, messages=messages)
    return {"keyword": keyword, "joke": joke}



In [None]:
# Fetch the latest published version of the "ds_eval" object from Weave.
ds_eval = weave.ref('ds_eval:latest').get()

In [None]:
# Smoke test: write one joke for the first evaluation row's keyword,
# then print that keyword followed by the generated joke.
res = await humor_writer(keyword=ds_eval.rows[0]['keyword'], model="mistral-medium-latest")
print(ds_eval.rows[0]['keyword'])
print(res["joke"])

In [None]:
class MistralModel(weave.Model):
    """Weave model that writes a Chilean Spanish joke with a Mistral chat model."""
    # Name of the Mistral model to call, e.g. "mistral-medium-latest".
    model: str
    # Sampling temperature forwarded to the chat API on every prediction.
    temperature: float = 0.7

    @weave.op
    def create_messages(self, keyword:str):
        "Build the chat prompt for `keyword` using the module-level `create_messages`"
        return create_messages(keyword)

    @weave.op
    async def predict(self, keyword:str):
        "Generate a joke for `keyword` and return the model's reply text"
        messages = self.create_messages(keyword)
        # Forward the configured temperature: the field used to be declared
        # but was never passed to the API call.
        return await call_mistral(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
        )

In [None]:
# Joke writer backed by "mistral-medium-latest"; temperature stays at the class default.
mistral_medium = MistralModel(model="mistral-medium-latest")

In [None]:
async def async_foreach(sequence, func, max_concurrent_tasks):
    """Apply the coroutine function `func` to every item with bounded concurrency.

    This is an async generator: it yields ``(item, result)`` pairs in
    completion order, which is not necessarily the order of `sequence`.

    Args:
        sequence: Iterable of items to process.
        func: Async callable invoked as ``await func(item)``.
        max_concurrent_tasks: Maximum number of `func` calls running at once;
            must be at least 1.

    Raises:
        ValueError: If `max_concurrent_tasks` is smaller than 1.
        Exception: Whatever `func` raises for an item; tasks that have not
            finished yet are cancelled before the error propagates.
    """
    if max_concurrent_tasks < 1:
        # A semaphore with zero slots would make every task wait forever.
        raise ValueError("max_concurrent_tasks must be at least 1")

    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def process_item(item):
        # Hold a semaphore slot only while `func` is actually running.
        async with semaphore:
            return item, await func(item)

    tasks = [asyncio.create_task(process_item(item)) for item in sequence]
    try:
        for next_done in asyncio.as_completed(tasks):
            yield await next_done
    finally:
        # On failure, or when the consumer stops iterating early, stop the
        # remaining work instead of leaving orphaned tasks in the background.
        for task in tasks:
            if not task.done():
                task.cancel()
        
async def map(ds, func, max_concurrent_tasks = 7, col_name="model_preds"):
    """Run `func` over every row of `ds` and attach its result as a new column.

    Note that this name shadows the builtin `map` inside the notebook.

    Args:
        ds: Object exposing a `rows` iterable of dict-like rows.
        func: Async callable that receives one whole row and returns the value
            to store.
        max_concurrent_tasks: Upper bound on concurrent `func` calls.
        col_name: Key under which each result is stored.

    Returns:
        A list of new row dicts (original fields plus `col_name`), in the
        order the calls completed, which may differ from the order of `ds.rows`.
    """
    new_dataset = []
    async for example, map_results in async_foreach(ds.rows, func, max_concurrent_tasks):
        # Build a fresh dict rather than calling `example.update(...)`, so the
        # rows of the input dataset are left untouched.
        new_dataset.append({**example, col_name: map_results})
    return new_dataset

ds_eval_medium_rows = await map(ds_eval, mistral_medium.predict, col_name="mistral_medium")

In [None]:
# Wrap the rows carrying the medium model's predictions in a Weave dataset and publish it.
ds_eval_medium = weave.Dataset(name="ds_eval_medium", description="Mistral medium predictions", rows=ds_eval_medium_rows)
weave.publish(ds_eval_medium)

In [None]:
# Re-load the latest published "ds_eval_medium" dataset from Weave.
ds_eval_medium = weave.ref('ds_eval_medium:latest').get()

In [None]:
mistral_7b = MistralModel(model="open-mistral-7b")
ds_eval_7b_rows = await map(ds_eval_medium, mistral_7b.predict, col_name="mistral_7b")
ds_eval_7b_medium = weave.Dataset(name="ds_eval_medium_7b", description="Mistral 7b predictions along with medium", rows=ds_eval_7b_rows)
weave.publish(ds_eval_7b_medium)

In [None]:
class LLMJudge(weave.Model):
    """LLM-as-a-judge that picks the better of two generated jokes."""
    # Mistral model used as the judge.
    model: str = "mistral-large-latest"

    @weave.op
    async def predict(self, keyword: str, mistral_7b: str, mistral_medium: str, text: str, **kwargs) -> dict:
        """Ask the judge model which of the two jokes is better.

        Args:
            keyword: Keyword both jokes were written about.
            mistral_7b: Joke shown to the judge as "joke1".
            mistral_medium: Joke shown to the judge as "joke2".
            text: Ground-truth joke given to the judge as a reference.
            **kwargs: Ignored; lets a whole dataset row be unpacked into the call.

        Returns:
            The judge's reply parsed from JSON.

        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
        """
        # Each sentence ends with an explicit space or newline: adjacent string
        # literals are concatenated with nothing in between, which previously
        # fused the opening sentences ("...in Chile.You have...").
        prompt = (
            "You are a world class comedian and you are judging a joke competition in Chile. "
            "You have to pick the best joke between two jokes written about a keyword. "
            "Take into consideration the jokes were written in Chilean Spanish and a ground truth joke as a reference. \n"
            "Here is the keyword: {keyword}\n"
            "Here is the joke1: {mistral_7b}\n"
            "Here is the joke2: {mistral_medium}\n"
            "Ground truth joke: {joke}\n"
            "Return the name of the best_joke (or None if you think both are bad) and the reason in short JSON object."
        ).format(
            keyword=keyword,
            mistral_7b=mistral_7b,
            mistral_medium=mistral_medium,
            joke=text,
        )
        messages = [ChatMessage(role="user", content=prompt)]
        # Request a JSON object so the reply can be parsed directly.
        payload = await call_mistral(model=self.model, messages=messages, response_format={"type": "json_object"})
        return json.loads(payload)

In [None]:
# Inspect the keys available on the first row.
ds_eval_7b_medium.rows[0].keys()

In [None]:
# Try the judge on a single row: the row's fields are unpacked as keyword arguments to `predict`.
llm_judge = LLMJudge()
res = await llm_judge.predict(**ds_eval_7b_medium.rows[0])
# Display the parsed verdict as the cell output.
res

In [None]:
@weave.op
def evaluate_joke(model_output: dict) -> dict:
    """Score one judge verdict: did "joke1" win?

    Args:
        model_output: Parsed JSON reply of the judge; expected to carry a
            "best_joke" entry.

    Returns:
        ``{"win": True}`` when the judge named "joke1" as the best joke,
        otherwise ``{"win": False}``.
    """
    # A reply that is not a dict, or that lacks "best_joke", counts as a loss
    # instead of raising in the middle of an evaluation run.
    if not isinstance(model_output, dict):
        return {"win": False}
    return {"win": model_output.get("best_joke") == "joke1"}

In [None]:
# Evaluation over the 7B-vs-medium dataset, scored with `evaluate_joke`.
evaluation = weave.Evaluation(dataset=ds_eval_7b_medium, scorers=[evaluate_joke])

In [None]:
# Run the evaluation with the LLM judge as the model being evaluated.
await evaluation.evaluate(llm_judge)

In [None]:
def format_messages(row):
    """Turn one jokes row into a MistralAI fine-tuning example.

    Reads ``row['keyword']`` and ``row['text']`` and returns
    ``{"messages": [...]}``: the prompt from `create_messages` as plain dicts,
    followed by the joke as the final assistant turn.
    """
    prompt = create_messages(row['keyword'], cls=dict)
    # the training target: the reference joke as the assistant's answer 👇
    answer = dict(role="assistant", content=row['text'])
    return {"messages": prompt + [answer]}

In [None]:
# Load the processed jokes (JSON Lines: one record per line).
df = pd.read_json('../data/processed/jokes.jsonl', lines=True)
# Sample 95% of the rows for training with a fixed seed; the remaining rows form the eval split.
df_train=df.sample(frac=0.95, random_state=200)
df_eval=df.drop(df_train.index)
# Show both split sizes as the cell output.
len(df_train), len(df_eval)

In [None]:
# Convert every row into the {"messages": [...]} structure built by `format_messages`.
formatted_df_train = df_train.apply(format_messages, axis=1)
formatted_df_eval = df_eval.apply(format_messages, axis=1)
# Preview the first formatted training examples.
formatted_df_train.head()

In [None]:
# Write both formatted splits to disk as JSON Lines files.
formatted_df_train.to_json("../data/processed/formatted_df_train.jsonl", orient="records", lines=True)
formatted_df_eval.to_json("../data/processed/formatted_df_eval.jsonl", orient="records", lines=True)

In [None]:
# Switch to the synchronous client for the file-upload and fine-tuning calls.
# NOTE(review): this rebinds the global `client` that `call_mistral` awaits on;
# cells using `call_mistral` will not work until `client` is an async client again.
client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# Upload the training and validation files; the ids of the returned objects
# are what the fine-tuning job is created with.
with open("../data/processed/formatted_df_train.jsonl", "rb") as f:
    ds_train = client.files.create(file=("formatted_df_train.jsonl", f))
# NOTE(review): `ds_eval` previously held the Weave evaluation dataset; from here on
# it refers to the uploaded validation file instead.
with open("../data/processed/formatted_df_eval.jsonl", "rb") as f:
    ds_eval = client.files.create(file=("eval.jsonl", f))

In [None]:
def pprint(obj):
    """Print `obj` as JSON indented by four spaces.

    `obj` must expose a `dict()` method whose result is JSON-serialisable.
    """
    payload = obj.dict()
    rendered = json.dumps(payload, indent=4)
    print(rendered)

In [None]:
# Show the object returned for the uploaded training file.
pprint(ds_train)

In [None]:
# Create the fine-tuning job on top of "open-mistral-7b", using the uploaded
# training and validation files and a Weights & Biases integration.
created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ds_train.id],
    validation_files=[ds_eval.id],
    hyperparameters=TrainingParameters(
        training_steps=25,
        learning_rate=0.0001,
        ),
    integrations=[
        WandbIntegrationIn(
            project="mistral_hackathon",
            run_name="finetune_wandb",
            # Fail fast with a KeyError when the key is missing, the same way
            # MISTRAL_API_KEY is read, instead of passing None along.
            api_key=os.environ["WANDB_API_KEY"],
        ).dict()
    ],
)

In [None]:
# Show the job description returned on creation.
pprint(created_jobs)

In [None]:


# Poll the fine-tuning job while it is RUNNING or QUEUED.
# NOTE(review): any other status ends the loop; confirm against the jobs API
# that no further non-terminal status exists.
retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)
    # Refresh only after waiting: a finished job is no longer followed by a
    # misleading "waiting" message and a needless sleep, and the first status
    # is not fetched twice in a row.
    retrieved_job = client.jobs.retrieve(created_jobs.id)
print(f"Job stopped polling with status {retrieved_job.status}")

In [None]:
# Fetch the job once more and display its current state; `retrieved_jobs` is
# read later for the fine-tuned model's name.
retrieved_jobs = client.jobs.retrieve(created_jobs.id)
pprint(retrieved_jobs)

In [None]:
# Re-load the latest published "ds_eval_medium" dataset from Weave.
ds_eval_medium = weave.ref('ds_eval_medium:latest').get()

In [None]:
# Back to the async client: `call_mistral` awaits `client.chat`.
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

In [None]:
mistral_7b_ft = MistralModel(model=retrieved_jobs.fine_tuned_model)
ds_eval_7b_rows = await map(ds_eval_medium, mistral_7b_ft.predict, col_name="mistral_7b")
ds_eval_7b_ft_medium = weave.Dataset(name="ds_eval_medium_7b_ft", description="Finetuned Mistral 7b predictions along with medium", rows=ds_eval_7b_rows)
weave.publish(ds_eval_7b_ft_medium)

In [None]:
# Evaluation over the fine-tuned-7B-vs-medium dataset, scored with `evaluate_joke`.
evaluation = weave.Evaluation(dataset=ds_eval_7b_ft_medium, scorers=[evaluate_joke])

In [None]:
# Run the evaluation with the LLM judge as the model being evaluated.
await evaluation.evaluate(llm_judge)