In [1]:
import os, asyncio, json
from pathlib import Path

import weave

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

In [2]:
# Globals

DATA_PATH = Path("./data")
NUM_SAMPLES = 10 # Number of samples to use for evaluation, use None for all
PROJECT_NAME = "llm-judge-webinar"

In [3]:
def read_jsonl(path):
    "returns a list of dictionaries"
    with open(path, 'r') as file:
        return [json.loads(line) for line in file]

In [31]:
train_ds_usb = read_jsonl(DATA_PATH / "usb-train.jsonl")

train_ds = read_jsonl(DATA_PATH / "fib-train.jsonl")
val_ds = read_jsonl(DATA_PATH / "fib-val.jsonl")[0:NUM_SAMPLES]

In [6]:
client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

@weave.op()  # <---- add this and you are good to go
def call_mistral(model:str, messages:list, **kwargs) -> str:
    "Call the Mistral API"
    chat_response = client.chat(
        model=model,
        messages=messages,
        response_format={"type": "json_object"},
        **kwargs,
    )
    return json.loads(chat_response.choices[0].message.content)

In [7]:
prompt = """You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.
- Carefully read the full document and the provided summary.
- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.
Factually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0
Factually Consistent: If all statements in the summary are supported by the document, label it as 1

Highlight or list the specific statements in the summary that are inconsistent.
Provide a brief explanation of why each highlighted statement is inconsistent with the document.

Return in JSON format with `consistency` and a `reason` for the given choice.

Document: 
{premise}
Summary: 
{hypothesis}
"""

In [9]:
def format_prompt(prompt, premise: str, hypothesis: str, cls=ChatMessage):
    messages = [
        cls(
            role="user", 
            content=prompt.format(premise=premise, hypothesis=hypothesis)
        )
    ]
    return messages

In [10]:
weave.init(PROJECT_NAME)

weave version 0.50.7 is available!  To upgrade, please run:
 $ pip install weave --upgrade
Logged in as Weights & Biases user: capecape.
View Weave data at https://wandb.ai/capecape/llm-judge-webinar/weave


<weave.weave_client.WeaveClient at 0x15fca67d0>

In [11]:
class MistralModel(weave.Model):
    model: str
    prompt: str
    temperature: float = 0.7
    
    @weave.op
    def create_messages(self, premise:str, hypothesis:str):
        return format_prompt(self.prompt, premise, hypothesis)

    @weave.op
    def predict(self, premise:str, hypothesis:str):
        messages = self.create_messages(premise, hypothesis)
        return call_mistral(model=self.model, messages=messages, temperature=self.temperature)

## Eval

In [12]:
def accuracy(model_output, target):
    class_model_output = model_output.get('consistency') if model_output else None
    return {"accuracy": class_model_output == target}

In [13]:
class BinaryMetrics(weave.Scorer):
    class_name: str
    eps: float = 1e-8

    @weave.op()
    def summarize(self, score_rows) -> dict:
        # filter out None rows, model may error out sometimes...
        score_rows = [score for score in score_rows if score["correct"] is not None]
        # Compute f1, precision, recall
        tp = sum([not score["negative"] and score["correct"] for score in score_rows])
        fp = sum([not score["negative"] and not score["correct"] for score in score_rows])
        fn = sum([score["negative"] and not score["correct"] for score in score_rows])
        precision = tp / (tp + fp + self.eps)
        recall = tp / (tp + fn + self.eps)
        f1 = 2 * precision * recall / (precision + recall + self.eps)
        result = {"f1": f1, "precision": precision, "recall": recall}
        return result

    @weave.op()
    def score(self, target: dict, model_output: dict) -> dict:
        class_model_output = model_output.get(self.class_name) if model_output else None  # 0 or 1
        result = {
            "correct": class_model_output == target,
            "negative": not class_model_output,
        }
        return result

F1 = BinaryMetrics(class_name="consistency")

In [14]:
evaluation = weave.Evaluation(dataset=val_ds, scorers=[accuracy, F1])

## Fine-Tune FTW

This is pretty descent for both 😍. Let's see if fine-tuning improves this.

In [15]:
ft_prompt = """You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.
- Carefully read the full document and the provided summary.
- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.
Factually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0
Factually Consistent: If all statements in the summary are supported by the document, label it as 1

Return in JSON format with `consistency` for the given choice.

Document: 
{premise}
Summary: 
{hypothesis}
"""

answer = """{{"consistency": {label}}}"""

In [16]:
def format_prompt_ft(row, cls=dict, with_answer=True):
    "Format on the expected MistralAI fine-tuning dataset"
    premise = row['premise']
    hypothesis = row['hypothesis']
    messages = [
        cls(
            role="user", 
            content=prompt.format(premise=premise, hypothesis=hypothesis)
        )
    ]
    if with_answer:
        label = row['target']
        messages.append(
            cls(
                role="assistant",
            content=answer.format(label=label)
            )
        )
    return messages

In [17]:
format_prompt_ft(train_ds[0])

[{'role': 'user',
  'content': 'You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.\n- Carefully read the full document and the provided summary.\n- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.\nFactually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0\nFactually Consistent: If all statements in the summary are supported by the document, label it as 1\n\nHighlight or list the specific statements in the summary that are inconsistent.\nProvide a brief explanation of why each highlighted statement is inconsistent with the document.\n\nReturn in JSON format with `consistency` and a `reason` for the given choice.\n\nDocument: \nWendy Jane Crewson Crewson was born in Hamilton, Ontario, the daughter of June Doreen (née Thomas) and Robert Binnie Crewson. Also in 2012, Crewson bega

In [18]:
formatted_train_ds = [format_prompt_ft(row) for row in train_ds]
formatted_val_ds = [format_prompt_ft(row) for row in val_ds]

In [20]:
def save_jsonl(ds, path):
    with open(path, "w") as f:
        for row in ds:
            f.write(json.dumps(row) + "\n")
save_jsonl(formatted_train_ds, DATA_PATH/"formatted_train_usb.jsonl")

## Upload dataset

In [36]:
with open(DATA_PATH/"formatted_train_usb.jsonl", "rb") as f:
    ds_train_usb = client.files.create(file=("formatted_df_train_usb.jsonl", f))
with open(DATA_PATH/"formatted_train.jsonl", "rb") as f:
    ds_train_fib = client.files.create(file=("formatted_df_train.jsonl", f))

In [37]:
import json
def pprint(obj):
    print(json.dumps(obj.dict(), indent=4))

In [41]:
pprint(ds_train_usb)
pprint(ds_train_fib)

{
    "id": "68c58f46-0704-4d53-945a-6e62953703dc",
    "object": "file",
    "bytes": 7711646,
    "created_at": 1719954434,
    "filename": "formatted_df_train_usb.jsonl",
    "purpose": "fine-tune"
}
{
    "id": "b6d9b3cb-08d1-47e7-8ef8-f53a3129ae08",
    "object": "file",
    "bytes": 7711646,
    "created_at": 1719954436,
    "filename": "formatted_df_train.jsonl",
    "purpose": "fine-tune"
}


## Create a fine-tuning job on USB dataset

we pass both datasets as we can not do a 2 stage fine-tune. This may actually be better, and Eugene suggests to try this.

In [45]:
from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn

created_jobs = client.jobs.create(
    model="open-mistral-7b",
    training_files=[ds_train_usb.id, ds_train_fib.id],
    validation_files=None,
    hyperparameters=TrainingParameters(
        training_steps=100,
        learning_rate=0.0001,
        ),
    integrations=[
        WandbIntegrationIn(
            project=PROJECT_NAME,
            run_name="finetune-usb",
            api_key=os.environ.get("WANDB_API_KEY"),
        ).dict()
    ],
    # dry_run=True,
)

In [46]:
pprint(created_jobs)

{
    "id": "a6a4c3c2-b6dc-431a-aa04-7031596fcdcd",
    "hyperparameters": {
        "training_steps": 100,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1719954586,
    "modified_at": 1719954586,
    "training_files": [
        "68c58f46-0704-4d53-945a-6e62953703dc",
        "b6d9b3cb-08d1-47e7-8ef8-f53a3129ae08"
    ],
    "validation_files": [],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "llm-judge-webinar",
            "name": null,
            "run_name": "finetune-usb"
        }
    ]
}


In [47]:
import time

retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.jobs.retrieve(created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)



{
    "id": "a6a4c3c2-b6dc-431a-aa04-7031596fcdcd",
    "hyperparameters": {
        "training_steps": 100,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1719954586,
    "modified_at": 1719954586,
    "training_files": [
        "68c58f46-0704-4d53-945a-6e62953703dc",
        "b6d9b3cb-08d1-47e7-8ef8-f53a3129ae08"
    ],
    "validation_files": [],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "llm-judge-webinar",
            "name": null,
            "run_name": "finetune-usb"
        }
    ],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1719954586
        },
        {
            "name": "status-updated",
            "data": {
                "status": "QUEUED"
            },
            "cre

## Use a fine-tuned model

Let's compute the predictions using the fine-tuned 7B model

In [26]:
# jobs = client.jobs.list()
# retrieved_job = jobs.data[0]

In [34]:
mistral_usb_fib_ft_chkpt = retrieved_job.fine_tuned_model

In [29]:
mistral_7b_ft = MistralModel(prompt=ft_prompt, model=mistral_usb_fib_ft_chkpt)

In [30]:
await evaluation.evaluate(mistral_7b_ft)

🍩 https://wandb.ai/capecape/llm-judge-webinar/r/call/eb509400-be73-4f8f-a9dd-da83c096fd49


{'model_output': {'consistency': {'mean': 1.0}},
 'accuracy': {'accuracy': {'true_count': 5, 'true_fraction': 0.5}},
 'BinaryMetrics': {'f1': 0.6666666613333334,
  'precision': 0.49999999949999996,
  'recall': 0.9999999980000001},
 'model_latency': {'mean': 5.256523180007934}}

## Create a fine-tuning job on FIB dataset

In [35]:
from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn

created_jobs = client.jobs.create(
    model=mistral_usb_ft_chkpt,
    training_files=[ds_train_fib.id],
    validation_files=None,
    hyperparameters=TrainingParameters(
        training_steps=50,
        learning_rate=0.0001,
        ),
    integrations=[
        WandbIntegrationIn(
            project=PROJECT_NAME,
            run_name="finetune-usb",
            api_key=os.environ.get("WANDB_API_KEY"),
        ).dict()
    ],
)

MistralAPIException: Status: 422. Message: {"detail": [{"type": "enum", "loc": ["body", "job_in", "model"], "msg": "Input should be 'open-mistral-7b' or 'mistral-small-latest'", "ctx": {"expected": "'open-mistral-7b' or 'mistral-small-latest'"}}]}

In [None]:
pprint(created_jobs)

{
    "id": "0decfcf2-5874-429f-9ebc-69749d046c3d",
    "hyperparameters": {
        "training_steps": 100,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1719952396,
    "modified_at": 1719952396,
    "training_files": [
        "f94cf728-b341-484b-99c6-058085774784"
    ],
    "validation_files": [],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "llm-judge-webinar",
            "name": null,
            "run_name": "finetune-usb"
        }
    ]
}


In [None]:
import time

retrieved_job = client.jobs.retrieve(created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.jobs.retrieve(created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)



{
    "id": "0decfcf2-5874-429f-9ebc-69749d046c3d",
    "hyperparameters": {
        "training_steps": 100,
        "learning_rate": 0.0001
    },
    "fine_tuned_model": null,
    "model": "open-mistral-7b",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1719952396,
    "modified_at": 1719952396,
    "training_files": [
        "f94cf728-b341-484b-99c6-058085774784"
    ],
    "validation_files": [],
    "object": "job",
    "integrations": [
        {
            "type": "wandb",
            "project": "llm-judge-webinar",
            "name": null,
            "run_name": "finetune-usb"
        }
    ],
    "events": [
        {
            "name": "status-updated",
            "data": {
                "status": "RUNNING"
            },
            "created_at": 1719952396
        },
        {
            "name": "status-updated",
            "data": {
                "status": "QUEUED"
            },
            "created_at": 1719952396
        }
    ],
    "check

## Use a fine-tuned model

Let's compute the predictions using the fine-tuned 7B model

In [None]:
# jobs = client.jobs.list()
# retrieved_job = jobs.data[0]

In [None]:
retrieved_job.fine_tuned_model

'ft:open-mistral-7b:0362203c:20240702:0decfcf2'

In [None]:
mistral_7b_ft = MistralModel(prompt=ft_prompt, model=retrieved_job.fine_tuned_model)

In [None]:
await evaluation.evaluate(mistral_7b_ft)

🍩 https://wandb.ai/capecape/llm-judge-webinar/r/call/eb509400-be73-4f8f-a9dd-da83c096fd49


{'model_output': {'consistency': {'mean': 1.0}},
 'accuracy': {'accuracy': {'true_count': 5, 'true_fraction': 0.5}},
 'BinaryMetrics': {'f1': 0.6666666613333334,
  'precision': 0.49999999949999996,
  'recall': 0.9999999980000001},
 'model_latency': {'mean': 5.256523180007934}}