<a href="https://colab.research.google.com/github/AlexChalakov/a16z-hackathon/blob/main/Mistral_finetune_api.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine tuning Mistral Large model

(Notebook adapted from: https://docs.mistral.ai/capabilities/finetuning/)

In [3]:
!pip install mistralai pandas

/bin/bash: line 1: brew: command not found


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
import pandas as pd
import json

## Prepare EmpatheticDiaglogues dataset

https://dl.fbaipublicfiles.com/parlai/empatheticdialogues/empatheticdialogues.tar.gz

In [7]:
train_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/Mistral AI hackathon/empatheticdialogues/train.csv')
valid_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/Mistral AI hackathon/empatheticdialogues/valid.csv')
test_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/Mistral AI hackathon/empatheticdialogues/test.csv')

In [8]:
train_df.head()

Unnamed: 0,conv_id,utterance_idx,context,prompt,speaker_idx,utterance,selfeval
0,hit:0_conv:1,1,sentimental,I remember going to the fireworks with my best...,1,I remember going to see the fireworks with my ...,5|5|5_2|2|5
1,hit:0_conv:1,2,sentimental,I remember going to the fireworks with my best...,0,Was this a friend you were in love with_comma_...,5|5|5_2|2|5
2,hit:0_conv:1,3,sentimental,I remember going to the fireworks with my best...,1,This was a best friend. I miss her.,5|5|5_2|2|5
3,hit:0_conv:1,4,sentimental,I remember going to the fireworks with my best...,0,Where has she gone?,5|5|5_2|2|5
4,hit:0_conv:1,5,sentimental,I remember going to the fireworks with my best...,1,We no longer talk.,5|5|5_2|2|5


In [30]:
def replace_comma(text):
    return text.replace('_comma_', ',')

In [33]:
train_df = train_df.dropna()
valid_df = valid_df.dropna()
test_df = test_df.dropna()

In [34]:
train_df['utterance'] = train_df['utterance'].apply(replace_comma)
valid_df['utterance'] = valid_df['utterance'].apply(replace_comma)
test_df['utterance'] = test_df['utterance'].apply(replace_comma)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df['utterance'] = test_df['utterance'].apply(replace_comma)


In [32]:
replace_comma('Here is a_comma_ comma')

'Here is a, comma'

In [35]:
def convert_to_jsonl(df):
    conversations = []
    current_conv_id = None
    current_messages = []
    user_index = None

    for _, row in df.iterrows():
        conv_id = row['conv_id']
        utterance = row['utterance']
        speaker_idx = row['speaker_idx']

        # Determine role: alternate turns represent user and assistant
        if current_conv_id != conv_id:
            # Start a new conversation
            if len(current_messages) > 0:
                conversations.append({"messages": current_messages})

            current_conv_id = conv_id
            current_messages = []
            user_index = speaker_idx

        # Assign roles alternately based on speaker index sequence
        role = "user" if speaker_idx == user_index else "assistant"

        # Append the message to current conversation
        current_messages.append({"role": role, "content": utterance})


    return conversations

In [38]:
train_df_json = convert_to_jsonl(train_df)
valid_df_json = convert_to_jsonl(valid_df)
test_df_json = convert_to_jsonl(test_df)

In [37]:
test_df_json

[{'messages': [{'role': 'user',
    'content': 'Yeah about 10 years ago I had a horrifying experience. It was 100% their fault but they hit the water barrels and survived. They had no injuries but they almost ran me off the road.'},
   {'role': 'assistant', 'content': 'Did you suffer any injuries?'},
   {'role': 'user',
    'content': "No I wasn't hit. It turned out they were drunk. I felt guilty but realized it was his fault."},
   {'role': 'assistant',
    'content': "Why did you feel guilty? People really shouldn't drive drunk."},
   {'role': 'user',
    'content': "I don't know I was new to driving and hadn't experienced anything like that. I felt like my horn made him swerve into the water barrels."}]},
 {'messages': [{'role': 'user',
    'content': 'Well, can you tell me about your experience? I think we swapped places'},
   {'role': 'assistant',
    'content': 'Yeah i wanted to tell you about the time i was hit by a drunk driver im so happy to still be alive after that experienc

In [42]:
output_file = '/content/train.jsonl'

# Write the list of conversations to a JSONL file
with open(output_file, 'w') as jsonl_file:
    for conversation in train_df_json:
        jsonl_file.write(json.dumps(conversation) + '\n')

In [44]:
output_file = '/content/valid.jsonl'

# Write the list of conversations to a JSONL file
with open(output_file, 'w') as jsonl_file:
    for conversation in valid_df_json:
        jsonl_file.write(json.dumps(conversation) + '\n')

In [45]:
output_file = '/content/test.jsonl'

# Write the list of conversations to a JSONL file
with open(output_file, 'w') as jsonl_file:
    for conversation in test_df_json:
        jsonl_file.write(json.dumps(conversation) + '\n')

## Reformat dataset
If you upload this ultrachat_chunk_train.jsonl to Mistral API, you might encounter an error message “Invalid file format” due to data formatting issues. To reformat the data into the correct format, you can download the reformat_dataset.py script and use it to validate and reformat both the training and evaluation data:

In [39]:
# download the validation and reformat script
!wget https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py

--2024-10-06 10:00:23--  https://raw.githubusercontent.com/mistralai/mistral-finetune/main/utils/reformat_data.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3381 (3.3K) [text/plain]
Saving to: ‘reformat_data.py’


2024-10-06 10:00:24 (53.9 MB/s) - ‘reformat_data.py’ saved [3381/3381]



In [43]:
# validate and reformat the training data
!python reformat_data.py train.jsonl

Skipped 9th sample
Skipped 325th sample
Skipped 435th sample
Skipped 529th sample
Skipped 768th sample
Skipped 1052th sample
Skipped 1061th sample
Skipped 1083th sample
Skipped 1090th sample
Skipped 1091th sample
Skipped 1953th sample
Skipped 1954th sample
Skipped 2218th sample
Skipped 2774th sample
Skipped 3361th sample
Skipped 4214th sample
Skipped 4840th sample
Skipped 5255th sample
Skipped 5392th sample
Skipped 6922th sample
Skipped 7009th sample
Skipped 7920th sample
Skipped 8233th sample
Skipped 8520th sample
Skipped 8521th sample
Skipped 8749th sample
Skipped 8750th sample
Skipped 10053th sample
Skipped 11197th sample
Skipped 12209th sample
Skipped 12361th sample
Skipped 13044th sample
Skipped 13655th sample
Skipped 13855th sample
Skipped 14402th sample
Skipped 15219th sample
Skipped 15220th sample
Skipped 15723th sample
Skipped 15831th sample
Skipped 15832th sample
Skipped 16477th sample
Skipped 16539th sample
Skipped 16925th sample
Skipped 17200th sample
Skipped 17258th sample

In [46]:
# validate the reformat the eval data
!python reformat_data.py valid.jsonl

Skipped 1247th sample


In [None]:
df_train.iloc[3674]['messages']

array([{'content': 'What are the dimensions of the cavity, product, and shipping box of the Sharp SMC1662DS microwave?: With innovative features like preset controls, Sensor Cooking and the Carousel® turntable system, the Sharp® SMC1662DS 1.6 cu. Ft. Stainless Steel Carousel Countertop Microwave makes reheating your favorite foods, snacks and beverages easier than ever. Use popcorn and beverage settings for one-touch cooking. Express Cook allows one-touch cooking up to six minutes. The convenient and flexible "+30 Sec" key works as both instant start option and allows you to add more time during cooking.\nThe Sharp SMC1662DS microwave is a bold design statement in any kitchen. The elegant, grey interior and bright white, LED interior lighting complements the stainless steel finish of this premium appliance.\nCavity Dimensions (w x h x d): 15.5" x 10.2" x 17.1"\nProduct Dimensions (w x h x d): 21.8" x 12.8" x 17.7"\nShipping Dimensions (w x h x d) : 24.4" x 15.0" x 20.5"', 'role': 'user

## Upload dataset

In [52]:
from mistralai import Mistral
import os

api_key = userdata.get('MISTRAL_API_KEY') #os.environ["MISTRAL_API_KEY"]

client = Mistral(api_key=api_key)

train = client.files.upload(file={
    "file_name": "train.jsonl",
    "content": open("train.jsonl", "rb"),
})
valid = client.files.upload(file={
    "file_name": "valid.jsonl",
    "content": open("valid.jsonl", "rb"),
})

In [53]:
import json
def pprint(obj):
    print(json.dumps(obj.dict(), indent=4))

In [54]:
pprint(train)

{
    "id": "7666ee20-04d6-4426-a98d-d9f057a40d17",
    "object": "file",
    "bytes": 8645590,
    "created_at": 1728209570,
    "filename": "train.jsonl",
    "sample_type": "instruct",
    "source": "upload",
    "purpose": null,
    "num_lines": 18361
}


In [55]:
pprint(valid)

{
    "id": "e224b685-143f-4b1a-ab1d-d1e3e2041680",
    "object": "file",
    "bytes": 1369301,
    "created_at": 1728209571,
    "filename": "valid.jsonl",
    "sample_type": "instruct",
    "source": "upload",
    "purpose": null,
    "num_lines": 2761
}


## Create a fine-tuning job

In [58]:
created_jobs = client.fine_tuning.jobs.create(
    model="mistral-large-latest",#"pixtral-12b-2409",
    training_files=[{"file_id": train.id, "weight": 1}],
    validation_files=[valid.id],
    hyperparameters={
    "training_steps": 10,
    "learning_rate":0.001
    },
    auto_start=True
)
created_jobs

JobOut(id='2f360255-6813-49e9-aa88-db52f6c8df96', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.001, weight_decay=0.1, warmup_fraction=0.05, epochs=None, fim_ratio=None), model='mistral-large-latest', status='QUEUED', job_type='FT', created_at=1728209855, modified_at=1728209855, training_files=['7666ee20-04d6-4426-a98d-d9f057a40d17'], validation_files=['e224b685-143f-4b1a-ab1d-d1e3e2041680'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=None, cost=0.0, cost_currency=None, train_tokens_per_step=None, train_tokens=None, data_tokens=None, estimated_start_time=None))

In [59]:
pprint(created_jobs)

{
    "id": "2f360255-6813-49e9-aa88-db52f6c8df96",
    "auto_start": true,
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.001,
        "weight_decay": 0.1,
        "warmup_fraction": 0.05,
        "epochs": null,
        "fim_ratio": null
    },
    "model": "mistral-large-latest",
    "status": "QUEUED",
    "job_type": "FT",
    "created_at": 1728209855,
    "modified_at": 1728209855,
    "training_files": [
        "7666ee20-04d6-4426-a98d-d9f057a40d17"
    ],
    "validation_files": [
        "e224b685-143f-4b1a-ab1d-d1e3e2041680"
    ],
    "fine_tuned_model": null,
    "suffix": null,
    "integrations": [],
    "trained_tokens": null,
    "repositories": [],
    "metadata": {
        "expected_duration_seconds": null,
        "cost": 0.0,
        "cost_currency": null,
        "train_tokens_per_step": null,
        "train_tokens": null,
        "data_tokens": null,
        "estimated_start_time": null
    }
}


In [60]:
jobs = client.fine_tuning.jobs.list()
print(jobs)

total=22 data=[JobOut(id='2f360255-6813-49e9-aa88-db52f6c8df96', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.001, weight_decay=0.1, warmup_fraction=0.05, epochs=1.2128452107595227, fim_ratio=None), model='mistral-large-latest', status='RUNNING', job_type='FT', created_at=1728209855, modified_at=1728209859, training_files=['7666ee20-04d6-4426-a98d-d9f057a40d17'], validation_files=['e224b685-143f-4b1a-ab1d-d1e3e2041680'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=900, cost=23.6, cost_currency='USD', train_tokens_per_step=262144, train_tokens=2621440, data_tokens=2161397, estimated_start_time=None)), JobOut(id='1efa1a26-50fe-49ac-89e2-f74c13f757aa', auto_start=False, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.0001, weight_decay=0.1, warmup_fraction=0.05, epochs=18.259603663845645, fim_ratio=None), model='mist

In [61]:
retrieved_jobs = client.fine_tuning.jobs.get(job_id = created_jobs.id)
retrieved_jobs

DetailedJobOut(id='2f360255-6813-49e9-aa88-db52f6c8df96', auto_start=True, hyperparameters=TrainingParameters(training_steps=10, learning_rate=0.001, weight_decay=0.1, warmup_fraction=0.05, epochs=1.2128452107595227, fim_ratio=None), model='mistral-large-latest', status='RUNNING', job_type='FT', created_at=1728209855, modified_at=1728209859, training_files=['7666ee20-04d6-4426-a98d-d9f057a40d17'], validation_files=['e224b685-143f-4b1a-ab1d-d1e3e2041680'], OBJECT='job', fine_tuned_model=None, suffix=None, integrations=[], trained_tokens=None, repositories=[], metadata=JobMetadataOut(expected_duration_seconds=900, cost=23.6, cost_currency='USD', train_tokens_per_step=262144, train_tokens=2621440, data_tokens=2161397, estimated_start_time=None), events=[EventOut(name='status-updated', created_at=1728209858, data={'status': 'RUNNING'}), EventOut(name='status-updated', created_at=1728209856, data={'status': 'QUEUED'}), EventOut(name='status-updated', created_at=1728209856, data={'status': '

In [62]:
import time

retrieved_job = client.fine_tuning.jobs.get(job_id = created_jobs.id)
while retrieved_job.status in ["RUNNING", "QUEUED"]:
    retrieved_job = client.fine_tuning.jobs.get(job_id = created_jobs.id)
    pprint(retrieved_job)
    print(f"Job is {retrieved_job.status}, waiting 10 seconds")
    time.sleep(10)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
            }
        },
        {
            "name": "status-updated",
            "created_at": 1728209855,
            "data": {
                "status": "QUEUED"
            }
        }
    ],
    "checkpoints": []
}
Job is RUNNING, waiting 10 seconds
{
    "id": "2f360255-6813-49e9-aa88-db52f6c8df96",
    "auto_start": true,
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.001,
        "weight_decay": 0.1,
        "warmup_fraction": 0.05,
        "epochs": 1.2128452107595227,
        "fim_ratio": null
    },
    "model": "mistral-large-latest",
    "status": "RUNNING",
    "job_type": "FT",
    "created_at": 1728209855,
    "modified_at": 1728209859,
    "training_files": [
        "7666ee20-04d6-4426-a98d-d9f057a40d17"
    ],
    "validation_files": [
        "e224b685-143f-4b1a-ab1d-d1e3e2041680"
    ],
    "fine_tuned_model": null,
    "suffix": null,
    "integrations": [],
    

In [64]:
# List jobs
jobs = client.fine_tuning.jobs.list()
pprint(jobs)

{
    "total": 22,
    "data": [
        {
            "id": "2f360255-6813-49e9-aa88-db52f6c8df96",
            "auto_start": true,
            "hyperparameters": {
                "training_steps": 10,
                "learning_rate": 0.001,
                "weight_decay": 0.1,
                "warmup_fraction": 0.05,
                "epochs": 1.2128452107595227,
                "fim_ratio": null
            },
            "model": "mistral-large-latest",
            "status": "SUCCESS",
            "job_type": "FT",
            "created_at": 1728209855,
            "modified_at": 1728210677,
            "training_files": [
                "7666ee20-04d6-4426-a98d-d9f057a40d17"
            ],
            "validation_files": [
                "e224b685-143f-4b1a-ab1d-d1e3e2041680"
            ],
            "fine_tuned_model": "ft:mistral-large-latest:5aa386c9:20241006:2f360255",
            "suffix": null,
            "integrations": [],
            "trained_tokens": 2621440,
       

In [65]:
# Retrieve a jobs
retrieved_jobs = client.fine_tuning.jobs.get(job_id = created_jobs.id)
pprint(retrieved_jobs)


{
    "id": "2f360255-6813-49e9-aa88-db52f6c8df96",
    "auto_start": true,
    "hyperparameters": {
        "training_steps": 10,
        "learning_rate": 0.001,
        "weight_decay": 0.1,
        "warmup_fraction": 0.05,
        "epochs": 1.2128452107595227,
        "fim_ratio": null
    },
    "model": "mistral-large-latest",
    "status": "SUCCESS",
    "job_type": "FT",
    "created_at": 1728209855,
    "modified_at": 1728210677,
    "training_files": [
        "7666ee20-04d6-4426-a98d-d9f057a40d17"
    ],
    "validation_files": [
        "e224b685-143f-4b1a-ab1d-d1e3e2041680"
    ],
    "fine_tuned_model": "ft:mistral-large-latest:5aa386c9:20241006:2f360255",
    "suffix": null,
    "integrations": [],
    "trained_tokens": 2621440,
    "repositories": [],
    "metadata": {
        "expected_duration_seconds": 900,
        "cost": 23.6,
        "cost_currency": "USD",
        "train_tokens_per_step": 262144,
        "train_tokens": 2621440,
        "data_tokens": 2161397,
    

## Use a fine-tuned model

In [None]:
chat_response = client.chat.complete(
    model = retrieved_jobs.fine_tuned_model,
    messages = [{"role":'user', "content":"I'm feeling down."}]
)

In [None]:
pprint(chat_response)

{
    "id": "1fac96713fd74799922712e34e009f81",
    "object": "chat.completion",
    "model": "ft:open-mistral-7b:b6e34a5e:20240719:20178c3c",
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 73,
        "total_tokens": 83
    },
    "created": 1721405725,
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {
                "content": "There isn't a single \"best\" French cheese as there are hundreds of different types of cheese to choose from, each with its unique taste and texture. Some popular French cheeses include Brie, Camembert, Roquefort, Comt\u00e9, and Ch\u00e8vre. Try different cheeses to find out the one you like best!",
                "tool_calls": null,
                "prefix": false,
                "role": "assistant"
            }
        }
    ]
}


## Integration with Weights and Biases
We can also offer support for integration with Weights & Biases (W&B) to monitor and track various metrics and statistics associated with our fine-tuning jobs. To enable integration with W&B, you will need to create an account with W&B and add your W&B information in the “integrations” section in the job creation request:



In [None]:
client.fine_tuning.jobs.create(
    model="open-mistral-7b",
    training_files=[{"file_id": ultrachat_chunk_train.id, "weight": 1}],
    validation_files=[ultrachat_chunk_eval.id],
    hyperparameters={"training_steps": 10, "learning_rate": 0.0001},
    integrations=[
        {
            "project": "<value>",
            "api_key": "<value>",
        }
    ]
)