In [12]:
from ssrq_retro_lab.config import PROJECT_ROOT
from ssrq_retro_lab.repository import reader
from ssrq_retro_lab.repository import writer
from ssrq_retro_lab.train import data, messages

In [13]:
import json

# Lookup table parsed from JSON. Later cells index it as
# table[volume_name][transcription_name] to obtain a PDF page number.
conversion_table_path = PROJECT_ROOT / "data/ZG/txt_to_pdf.json"
txt_pdf_conversion_table = json.loads(reader.TextReader(conversion_table_path).read())

In [14]:
from fitz_new import Document

def get_page_text_from_pdf(pdf: Document, page: int) -> str:
    """Return the extracted text of one page of ``pdf``.

    Args:
        pdf: The PDF document to read from.
        page: Page identifier passed straight to ``pdf.load_page``
            (presumably a 0-based page index — TODO confirm for ``fitz_new``).

    Returns:
        The text produced by the page's text page via ``extractText(sort=True)``.
    """
    text_page = pdf.load_page(page).get_textpage()
    return text_page.extractText(sort=True)

In [15]:
# glob() yields paths in a filesystem-dependent order, while the validation
# split is chosen by list position (see `validation_indicies`). Sorting keeps
# the train/validation split identical across machines and re-runs.
volumes = sorted((PROJECT_ROOT / "data/ZG/pdf").glob("*.pdf"))
# "*[0-9].txt" only matches names ending in a digit before ".txt", which skips
# the "*_splitted.txt" / "*_classified.txt" variants read by later cells.
master_transcriptions = sorted((PROJECT_ROOT / "data/ZG/master").glob("*[0-9].txt"))

# Training examples of all tasks are collected in one shared list;
# validation examples are kept in one list per task.
training_data: list[data.OpenAIDataset] = []
ocr_validation_data: list[data.OpenAIDataset] = []
splitting_validation_data: list[data.OpenAIDataset] = []
classification_validation_data: list[data.OpenAIDataset] = []

# Positions within each volume's list of transcriptions that are held out
# as validation data instead of being added to the training set.
validation_indicies = [3, 5, 9]

The following cells create the different types of training data used for fine-tuning a custom GPT model, which should act as a research assistant in the process of (retro)digitizing the Collection of Swiss Law Sources. We will create training data for the following tasks:

1. OCR-Correction
2. Text segmentation
3. Text classification

The training data will be saved to the file `openai_training_data.jsonl`. For each task type, a ~10% split is held out as validation data and saved to a separate file.

In [16]:
# Task 1 (OCR correction): pair the text extracted from each PDF page with the
# master transcription of that page.
for volume in volumes:
    doc = reader.PDFReader(volume).read()
    # The volume name is used both as the file-name prefix of the transcriptions
    # and as the first-level key of the conversion table; it is the PDF file
    # name without ".pdf" and with "." replaced by "_".
    volume_name = volume.name.removesuffix(".pdf").replace(".", "_")
    transcriptions = [
        candidate
        for candidate in master_transcriptions
        if candidate.name.startswith(volume_name)
    ]

    for index, transcription in enumerate(transcriptions):
        transcription_key = transcription.name.removesuffix(".txt")
        page_number = int(txt_pdf_conversion_table[volume_name][transcription_key])
        page_text = get_page_text_from_pdf(doc, page_number)
        master_text = reader.TextReader(transcription).read()

        training_set = data.create_openai_dataset(
            system_role=messages.SYSTEM_ROLE,
            user_template=messages.USER_OCR_CORRECTION,
            user_text=page_text,
            assistant_template=messages.ASSISTANT_OCR_CORRECTION,
            assistant_text=master_text,
        )

        # Held-out positions go to the OCR validation set, everything else
        # to the shared training set.
        target_list = (
            ocr_validation_data if index in validation_indicies else training_data
        )
        target_list.append(training_set)

In [17]:
from pathlib import Path

# Task 2 (text segmentation): pair the text extracted from each PDF page with
# the "<name>_splitted.txt" variant of its master transcription.
for volume in volumes:
    doc = reader.PDFReader(volume).read()
    # The volume name is used both as the file-name prefix of the transcriptions
    # and as the first-level key of the conversion table.
    volume_name = volume.name.removesuffix(".pdf").replace(".", "_")
    transcriptions = [
        transcription
        for transcription in master_transcriptions
        if transcription.name.startswith(volume_name)
    ]

    for index, transcription in enumerate(transcriptions):
        page_number = int(
            txt_pdf_conversion_table[volume_name][
                transcription.name.removesuffix(".txt")
            ]
        )
        page_text = get_page_text_from_pdf(
            doc,
            page_number,
        )
        # Derive the sibling file from the path components. A plain
        # str.replace(".txt", ...) on the whole absolute path would also
        # rewrite any directory name that happens to contain ".txt".
        splitted_path = transcription.absolute().with_name(
            f"{transcription.stem}_splitted.txt"
        )
        master_text = reader.TextReader(splitted_path).read()
        training_set = data.create_openai_dataset(
            system_role=messages.SYSTEM_ROLE,
            user_template=messages.USER_TEXT_SPLITTING,
            user_text=page_text,
            assistant_template=messages.ASSISTANT_TEXT_SPLITTING,
            assistant_text=master_text,
        )

        # Held-out positions go to the splitting validation set, everything
        # else to the shared training set.
        if index in validation_indicies:
            splitting_validation_data.append(training_set)
        else:
            training_data.append(training_set)

In [18]:
# Task 3 (text classification): pair the text extracted from each PDF page with
# the "<name>_classified.txt" variant of its master transcription.
for volume in volumes:
    doc = reader.PDFReader(volume).read()
    # The volume name is used both as the file-name prefix of the transcriptions
    # and as the first-level key of the conversion table.
    volume_name = volume.name.removesuffix(".pdf").replace(".", "_")
    transcriptions = [
        transcription
        for transcription in master_transcriptions
        if transcription.name.startswith(volume_name)
    ]

    for index, transcription in enumerate(transcriptions):
        page_number = int(
            txt_pdf_conversion_table[volume_name][
                transcription.name.removesuffix(".txt")
            ]
        )
        page_text = get_page_text_from_pdf(
            doc,
            page_number,
        )
        # Derive the sibling file from the path components. A plain
        # str.replace(".txt", ...) on the whole absolute path would also
        # rewrite any directory name that happens to contain ".txt".
        classified_path = transcription.absolute().with_name(
            f"{transcription.stem}_classified.txt"
        )
        master_text = reader.TextReader(classified_path).read()
        training_set = data.create_openai_dataset(
            system_role=messages.SYSTEM_ROLE,
            user_template=messages.USER_TEXT_CLASSIFICATION,
            user_text=page_text,
            assistant_template=messages.ASSISTANT_TEXT_CLASSIFIED,
            assistant_text=master_text,
        )

        # Held-out positions go to the classification validation set,
        # everything else to the shared training set.
        if index in validation_indicies:
            classification_validation_data.append(training_set)
        else:
            training_data.append(training_set)

In [19]:
# Persist the shared training set and the three per-task validation sets,
# one JSONL file each (written in the order listed).
jsonl_exports = [
    ("data/ZG/openai_training_data.jsonl", training_data),
    ("data/ZG/openai_ocr_validation.jsonl", ocr_validation_data),
    ("data/ZG/openai_splitting_validation.jsonl", splitting_validation_data),
    ("data/ZG/openai_classification_validation.jsonl", classification_validation_data),
]

for relative_path, datasets in jsonl_exports:
    writer.JSONLWriter(PROJECT_ROOT / relative_path).write(
        content=[dataset.to_dict() for dataset in datasets]
    )

In [20]:
# Round-trip check: read the exported training file back and compare its
# record count with the in-memory training set.
training_jsonl_path = PROJECT_ROOT / "data/ZG/openai_training_data.jsonl"
jsonl_training_data = reader.JSONLReader(training_jsonl_path).read()

record_counts_match = len(jsonl_training_data) == len(training_data)
print(record_counts_match)

True


In [21]:
from ssrq_retro_lab.train.validation import validate_openai_training_data

# Stop the notebook here unless the validator explicitly returns True
# for the re-read training data.
validation_passed = validate_openai_training_data(jsonl_training_data)
assert validation_passed is True

No errors found


In [22]:
from ssrq_retro_lab.train.stats import num_assistant_tokens_from_messages, num_tokens_from_messages,print_distribution

# Examples longer than this many tokens are reported below as over the limit.
TOKEN_LIMIT = 16385

# Warnings and tokens counts
n_missing_system = 0
n_missing_user = 0
n_messages = []
convo_lens = []
assistant_message_lens = []

for ex in jsonl_training_data:
    # Deliberately NOT named `messages`: that name is the imported
    # `ssrq_retro_lab.train.messages` module used by the dataset-building
    # cells, and rebinding it here would break re-running those cells.
    example_messages = ex["messages"]
    if not any(message["role"] == "system" for message in example_messages):
        n_missing_system += 1
    if not any(message["role"] == "user" for message in example_messages):
        n_missing_user += 1
    n_messages.append(len(example_messages))
    convo_lens.append(num_tokens_from_messages(example_messages))
    assistant_message_lens.append(num_assistant_tokens_from_messages(example_messages))

print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
print_distribution(n_messages, "num_messages_per_example")
print_distribution(convo_lens, "num_total_tokens_per_example")
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
# Count the conversations whose total token count exceeds the limit.
n_too_long = sum(length > TOKEN_LIMIT for length in convo_lens)
print(f"\n{n_too_long} examples may be over the {TOKEN_LIMIT} token limit, they will be truncated during fine-tuning")

Num examples missing system message: 0
Num examples missing user message: 0

#### Distribution of num_messages_per_example:
min / max: 3, 3
mean / median: 3.0, 3.0
p5 / p95: 3.0, 3.0

#### Distribution of num_total_tokens_per_example:
min / max: 1228, 2124
mean / median: 1724.9290780141844, 1724.0
p5 / p95: 1465.0, 1993.0

#### Distribution of num_assistant_tokens_per_example:
min / max: 570, 1011
mean / median: 789.3475177304964, 778.0
p5 / p95: 660.0, 931.0

0 examples may be over the 16385 token limit, they will be truncated during fine-tuning


In [23]:
# Pricing and default n_epochs estimate

TARGET_EPOCHS = 3
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
MIN_DEFAULT_EPOCHS = 1
MAX_DEFAULT_EPOCHS = 25

n_train_examples = len(jsonl_training_data)
examples_seen_at_target = n_train_examples * TARGET_EPOCHS

# Start from TARGET_EPOCHS; adjust the epoch count only when the number of
# examples seen at the target falls outside [MIN_TARGET_EXAMPLES, MAX_TARGET_EXAMPLES].
if examples_seen_at_target < MIN_TARGET_EXAMPLES:
    n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
elif examples_seen_at_target > MAX_TARGET_EXAMPLES:
    n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
else:
    n_epochs = TARGET_EPOCHS

# Each example contributes at most TOKEN_LIMIT tokens to the billed total.
n_billing_tokens_in_dataset = sum(min(TOKEN_LIMIT, length) for length in convo_lens)
n_billed_tokens_total = n_epochs * n_billing_tokens_in_dataset
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, you'll be charged for ~{n_billed_tokens_total} tokens")

# Hard-coded price per 100,000 tokens; update when pricing changes.
cost_per_100k_tokens = 0.80  # Cost for every 100,000 tokens
estimated_cost = (n_billed_tokens_total / 100000) * cost_per_100k_tokens
print(f"Estimated cost for fine-tuning: approximately ${estimated_cost:.2f}")

Dataset has ~243215 tokens that will be charged for during training
By default, you'll train for 3 epochs on this dataset
By default, you'll be charged for ~729645 tokens
Estimated cost for fine-tuning: approximately $5.84
