# Fine-tuning Mistral 7B Instruct for OCR Correction

Purpose of this notebook:
- create training data (`.jsonl`) for fine-tuning an LLM
- fine-tune the model locally (on a Mac) using the [`mlx`](https://github.com/ml-explore) library

Inspired by:
[https://apeatling.com/articles/simple-guide-to-local-llm-fine-tuning-on-a-mac-with-mlx/](https://apeatling.com/articles/simple-guide-to-local-llm-fine-tuning-on-a-mac-with-mlx/)

In [7]:
import pandas as pd

from ssrq_retro_lab.config import PROJECT_ROOT, ZG_DATA_ROOT
from ssrq_retro_lab.pipeline.templates.utils import render_template
from ssrq_retro_lab.repository.writer import JSONLWriter

In [6]:
ocr_training_df = pd.read_pickle(
    PROJECT_ROOT / "notebooks" / "pkl_cache" / "ocr_line_based_training.pkl"
)

In [14]:
prompts: list[str] = [
    render_template(
        "mistral_ocr_training_v1.jinja2", source=row["source"], target=row["target"]
    )
    for _, row in ocr_training_df.iterrows()
]

In [18]:
import random
from sklearn.model_selection import train_test_split

# Set the seed to 42 for reproducibility
random.seed(42)

random.shuffle(prompts)

test_prompts = prompts[:17]

train, validation = train_test_split(
    prompts[17:], test_size=0.2, shuffle=True, random_state=42
)

In [21]:
# Should have created 80% train, 20% validation and 17 test prompts
len(train), len(validation), len(test_prompts)

(480, 120, 17)
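
The sizes above follow from the split arithmetic. The following is a minimal sketch, assuming 617 rendered prompts in total (inferred from 480 + 120 + 17) and assuming that a float `test_size` is rounded up to a whole number of samples, which is how scikit-learn handles it. The helper `split_sizes` is hypothetical and not part of the project.

```python
import math


def split_sizes(
    n_total: int, n_test: int = 17, validation_share: float = 0.2
) -> tuple[int, int, int]:
    """Return (train, validation, test) sizes for a split like the one above."""
    # The first `n_test` shuffled prompts are held out as the test set.
    remaining = n_total - n_test
    # round() guards against floating-point noise before taking the ceiling.
    n_validation = math.ceil(round(remaining * validation_share, 9))
    n_train = remaining - n_validation
    return n_train, n_validation, n_test


# 617 is an assumption derived from the reported sizes, not a value from the data.
sizes = split_sizes(617)
print(sizes)
```

Holding out the test prompts before calling `train_test_split` keeps them out of both the training and the validation set.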

In [33]:
JSONLWriter(ZG_DATA_ROOT / "training_data" / "mistral_ocr" / "test.jsonl").write(
    content=[{"text": p} for p in test_prompts]
)

JSONLWriter(ZG_DATA_ROOT / "training_data" / "mistral_ocr" / "train.jsonl").write(
    content=[{"text": p} for p in train]
)

JSONLWriter(ZG_DATA_ROOT / "training_data" / "mistral_ocr" / "valid.jsonl").write(
    content=[{"text": p} for p in validation]
)
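
`JSONLWriter` is a project-internal class whose implementation is not shown here. As a rough illustration of the file layout the three calls above are expected to produce (one JSON object with a `text` key per line), here is a standard-library sketch. The functions `write_jsonl` and `read_jsonl` and the example prompts are hypothetical stand-ins, not the project's API.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write one JSON object per line (hypothetical stand-in for JSONLWriter)."""
    with path.open("w", encoding="utf-8") as fh:
        for record in records:
            # json.dumps escapes embedded newlines, so each record stays on one line.
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: Path) -> list[dict]:
    """Read the file back into a list of dicts, skipping empty lines."""
    with path.open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


# Made-up prompts for illustration only.
example_prompts = ["<s>[INST] first prompt [/INST]", "<s>[INST] second\nprompt [/INST]"]

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "train.jsonl"
    write_jsonl(target, [{"text": p} for p in example_prompts])
    roundtrip = read_jsonl(target)
    line_count = len(target.read_text(encoding="utf-8").splitlines())
```

Reading the file back is a cheap sanity check that multi-line prompts did not break the one-record-per-line layout.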

In [36]:
!python {str(PROJECT_ROOT / "lib/mlx_examples/lora/lora.py")} --train --model "mlx-community/Mistral-7B-Instruct-v0.2-4bit-mlx" --adapter-file {str(PROJECT_ROOT / "model" / "ssrq_mistral_ocr_adapter.npz")} --data {str(ZG_DATA_ROOT / "training_data" / "mistral_ocr")} --batch-size 2 --lora-layers 8 --iters 1000

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 5 files:   0%|                                   | 0/5 [00:00<?, ?it/s]
config.json: 100%|█████████████████████████| 2.26k/2.26k [00:00<00:00, 11.2MB/s]
Fetching 5 files:  20%|█████▍                     | 1/5 [00:00<00:01,  3.43it/s]
tokenizer.json:   0%|                               | 0.00/1.80M [00:00<?, ?B/s]
tokenizer_config.json: 100%|███████████████| 1.46k/1.46k [00:00<00:00, 8.49MB/s]
special_tokens_map.json: 100%|█████████████████| 414/414 [00:00<00:00, 2.38MB/s]
Fetching 5 files:  40%|██████████▊                | 2/5 [00:00<00:00,  4.30it/s]
tokenizer.json: 100%|██████████████████████| 1.80M/1.80M [00:00<00:00, 3.56MB/s]
Fetching 5 files:  60%|████████████████▏          | 3/5 [00:00<00:00,  3.67it/s]
weights.00.safetensors:   0%|                       | 0.00/4.26

In [38]:
!python {str(PROJECT_ROOT / "lib/mlx_examples/lora/lora.py")} --test --model "mlx-community/Mistral-7B-Instruct-v0.2-4bit-mlx" --adapter-file {str(PROJECT_ROOT / "model" / "ssrq_mistral_ocr_adapter.npz")} --data {str(ZG_DATA_ROOT / "training_data" / "mistral_ocr")}

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
Loading pretrained model
Fetching 5 files: 100%|████████████████████████| 5/5 [00:00<00:00, 72817.78it/s]
Total parameters 1244.041M
Trainable parameters 1.704M
Loading datasets
Testing
Test loss 0.912, Test ppl 2.489.


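In [43]:
# Keep only the instruction part of the first test prompt (everything before
# the JSON answer) and append "</s>". Plain concatenation avoids a backslash
# inside an f-string expression, which is a SyntaxError before Python 3.12.
test_prompts[0].split("\n{")[0] + "</s>"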

'<s>[INST] You are a helpful research assistant with extremely good knowledge in scholarly editing. Your task is to correct text snippets extracted from a printed scholarly edition with OCR. Correct them without modernizing. Respond with the corrected text as a valid JSON object. Here is the text to correct:\nso si dem gotzhuse getan hant, und von alter an si bracht ist, als\n[/INST]</s>'
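
The expression above keeps everything before the first newline that is followed by `{`, that is, the instruction without the expected JSON answer. Here is a small sketch of that step on a made-up example. The layout of the full prompt (answer on its own line after `[/INST]`) is inferred from the split, and the JSON key `text` and the helper `strip_answer` are hypothetical.

```python
def strip_answer(full_prompt: str) -> str:
    """Return the part of a training prompt that precedes the JSON answer."""
    # partition() splits at the first occurrence only, like split(...)[0].
    instruction, _, _ = full_prompt.partition("\n{")
    return instruction


# Made-up training prompt; the key "text" is an assumption for illustration.
example = (
    "<s>[INST] Here is the text to correct:\n"
    "so si dem gotzhuse\n"
    "[/INST]\n"
    '{"text": "so si dem gotzhuse"}</s>'
)

stripped = strip_answer(example)
print(stripped)
```

If a source line itself contained a newline followed by `{`, the split would cut the instruction short, so this only works as long as the OCR text has no such sequence.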

In [45]:
from mlx_lm import load, generate

model, tokenizer = load(
    "mlx-community/Mistral-7B-Instruct-v0.2-4bit-mlx",
    adapter_file=str(PROJECT_ROOT / "model" / "ssrq_mistral_ocr_adapter.npz"),
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [87]:
for test_prompt in test_prompts:
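    # Keep only the instruction part (everything before the JSON answer).
    # Two "</s>" tokens are appended, as the printed prompt in the output shows.
    # Plain concatenation replaces the nested f-string, which needs Python 3.12+.
    prompt_without_result: str = test_prompt.split("\n{")[0] + "</s></s>"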
    response = generate(
        model,
        tokenizer,
        prompt=prompt_without_result,
        verbose=True,
    )

Prompt: <s>[INST] You are a helpful research assistant with extremely good knowledge in scholarly editing. Your task is to correct text snippets extracted from a printed scholarly edition with OCR. Correct them without modernizing. Respond with the corrected text as a valid JSON object. Here is the text to correct:
so si dem gotzhuse getan hant, und von alter an si bracht ist, als
[/INST]</s></s>
<s> Question: Correct the following text without modernizing.
so si dem gotzhuse getan hant, und von alter an si bracht ist, als
-------------------------------------------------------------------
Corrected:
so si dem gotzhuse getan hant, und von alter an si bracht ist, als
[] So si dem gotzhuse getan hant, und von alter an si bracht ist, als
[/][
{
"cor
Prompt: 53.497 tokens-per-sec
Generation: 6.160 tokens-per-sec
Prompt: <s>[INST] You are a helpful research assistant with extremely good knowledge in scholarly editing. Your task is to correct text snippets extracted from a printed scholarly 

The cells above show that the fine-tuned Mistral model never answers with valid JSON, so its output is unusable. In other words: the fine-tuned model is not working.
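
A check like the one described can be automated instead of read off the printed output. Here is a minimal sketch, assuming a response counts as valid only if the whole string parses as a JSON object. The function `is_valid_json_object` and the key `text` in the well-formed example are hypothetical; the second string is shortened from the output printed above.

```python
import json


def is_valid_json_object(response: str) -> bool:
    """Return True if the whole response parses as a JSON object."""
    try:
        parsed = json.loads(response.strip())
    except json.JSONDecodeError:
        return False
    return isinstance(parsed, dict)


# Hypothetical well-formed answer; the key name is an assumption.
hypothetical_good = '{"text": "so si dem gotzhuse getan hant"}'
# Shortened from the generated output shown above.
observed_fragment = 'Corrected:\nso si dem gotzhuse getan hant\n[/][\n{\n"cor'

results = [
    is_valid_json_object(hypothetical_good),
    is_valid_json_object(observed_fragment),
]
print(results)
```

Applied to the `response` values collected in the generation loop, this would give a share of valid answers rather than an impression from the first few outputs.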