In [None]:
import os

import json
import time
import numpy as np
import tiktoken # for token counting
from collections import defaultdict

from src.prompts import *
from src.hapi_elpers import sample_cases, clear_openai_cache

from openai import OpenAI

In [131]:
class TrainingDataPreparer():
    def __init__(self):
        self.training_data = None
        self.validation_data = None

    def prepare_data(self, data_paths: list[str], annotation_types: dict = annotation_types):
        """Prepare data in the format required by OpenAI's fine-tuning API"""
        data = []
        
        for case_path in data_paths:
            with open(case_path, encoding='utf-8') as f:
                file = json.load(f)
            
            prompt = f"""Annotate the below law case file according to the following 5 annotation types:
            {json.dumps(annotation_types, indent=2)}

            Take your time, be as thorough as possible, and combine all the annotations from a single annotation type into a list of comma-separated strings. Do not include sources. Annotations must be direct, unedited quotes from the case file.

            Always respond to the user in JSON format where the keys are the annotation types, and the value for each key is an array (list) of strings where each string is a separate annotation relevant to the given key. Include no other text in your response.

            Law case file to annotate:
            {json.dumps(file['text'], indent=2)}"""

            response = json.dumps(file['annotations'], indent=2)

            # Create the training example
            training_example = {
                "messages": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response}
                ]
            }
            data.append(training_example)
        
        return data

    def save_data(self, data_id: str):
        """Save training data in JSONL format as required by OpenAI"""
        if self.training_data is None or self.validation_data is None:
            print("Prepare training and validation data before saving")
            return
        
        if not os.path.exists(f"../data/finetuning/{data_id}"):
            os.makedirs(f"../data/finetuning/{data_id}")
        
        training_save_path = f"../data/finetuning/{data_id}/train.jsonl"
        with open(training_save_path,'w') as f:
            for example in self.training_data:
                f.write(json.dumps(example) + '\n')

        validation_save_path = f"../data/finetuning/{data_id}/val.jsonl"
        with open(validation_save_path,'w') as f:
            for example in self.validation_data:
                f.write(json.dumps(example) + '\n')

In [132]:
class Finetuner():
    def __init__(self, model: str):
        self.client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
        self.model = model
        self.training_file_id = None
        self.validation_file_id = None
        self.job_id = None
        self.finetuned_model = None

    def upload_files(self, data_id: str) -> str:
        """Upload training file to OpenAI"""
        with open(f"../data/finetuning/{data_id}/train.jsonl", "rb") as f:
            response = self.client.files.create(
                file=f,
                purpose="fine-tune"
            )
        self.training_file_id = response.id
        print(f"Training file uploaded successfully. File ID: {self.training_file_id}")

        with open(f"../data/finetuning/{data_id}/val.jsonl", "rb") as f:
            response = self.client.files.create(
                file=f,
                purpose="fine-tune"
            )
        self.validation_file_id = response.id
        print(f"Validation file uploaded successfully. File ID: {self.validation_file_id}")


    def create_job(self, n_epochs: int = 5) -> str:
        """Create and start a fine-tuning job"""
        if self.training_file_id is None:
            raise ValueError("No training file uploaded. Call upload_training_file first.")
            
        job = self.client.fine_tuning.jobs.create(
            training_file=self.training_file_id,
            validation_file = self.validation_file_id,
            model=self.model,
            hyperparameters={
                "n_epochs": n_epochs
            },
            seed=1
        )
        
        self.job_id = job.id
        print(f"Fine-tuning job created successfully. Job ID: {self.job_id}")
        return self.client, self.job_id

In [None]:
gt_folder = "../data/cases_json/"
training_data_paths, validation_data_paths, test_gtpaths = sample_cases(gt_folder, n_train = 100, n_val = 10, n_test = 10, seed=2)

preparer = TrainingDataPreparer()
preparer.training_data = preparer.prepare_data(training_data_paths)
preparer.validation_data = preparer.prepare_data(validation_data_paths)

data_id = "25.11.24_100cases"
preparer.save_data(data_id)
# np.save(f"../data/finetuning/{data_id}/test.npy", test_gtpaths)

#### data validation

In [135]:
data_path = f"../data/finetuning/{data_id}/train.jsonl"

# Load the dataset
with open(data_path, 'r', encoding='utf-8') as f:
    dataset = [json.loads(line) for line in f]

In [136]:
# Format error checks
format_errors = defaultdict(int)

for ex in dataset:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue
        
    messages = ex.get("messages", None)
    if not messages:
        format_errors["missing_messages_list"] += 1
        continue
        
    for message in messages:
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1
        
        if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
            format_errors["message_unrecognized_key"] += 1
        
        if message.get("role", None) not in ("system", "user", "assistant", "function"):
            format_errors["unrecognized_role"] += 1
            
        content = message.get("content", None)
        function_call = message.get("function_call", None)
        
        if (not content and not function_call) or not isinstance(content, str):
            format_errors["missing_content"] += 1
    
    if not any(message.get("role", None) == "assistant" for message in messages):
        format_errors["example_missing_assistant_message"] += 1

if format_errors:
    print("Found errors:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("No errors found")

No errors found


#### try to finetune

In [177]:
clear_openai_cache()

Cache Clear


In [178]:
finetuner = Finetuner(model='gpt-4o-mini-2024-07-18')
finetuner.upload_files(data_id=data_id)

Training file uploaded successfully. File ID: file-7UYmyv6nZxyYKNS2RHAxfL
Validation file uploaded successfully. File ID: file-4yQm5QHKGonqKMJFoTz45m


In [179]:
client, job_id = finetuner.create_job(n_epochs=5)

Fine-tuning job created successfully. Job ID: ftjob-eboXhRujBQFOLpe4ImYLrzNH


In [182]:
for i in client.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=20).data:
    print(i.message)

The job has successfully completed
New fine-tuned model created
Step 99/500: training loss=0.14
Step 98/500: training loss=0.12
Step 97/500: training loss=0.06
Step 96/500: training loss=0.06
Step 95/500: training loss=0.08
Step 94/500: training loss=0.08
Step 93/500: training loss=0.08
Step 92/500: training loss=0.08
Step 91/500: training loss=0.07
Step 90/500: training loss=0.05, validation loss=0.05
Step 89/500: training loss=0.09
Step 88/500: training loss=0.13
Step 87/500: training loss=0.13
Step 86/500: training loss=0.07
Step 85/500: training loss=0.17
Step 84/500: training loss=0.12
Step 83/500: training loss=0.08
Step 82/500: training loss=0.12


In [184]:
ft_job = client.fine_tuning.jobs.retrieve(job_id)
ft_job

FineTuningJob(id='ftjob-eboXhRujBQFOLpe4ImYLrzNH', created_at=1732553004, error=Error(code=None, message=None, param=None), fine_tuned_model='ft:gpt-4o-mini-2024-07-18:university-of-oxford::AXWYSys7', finished_at=1732553327, hyperparameters=Hyperparameters(n_epochs=5, batch_size=1, learning_rate_multiplier=1.8), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-ZqpsyAHhGKahLCkPupZx6uhn', result_files=[], seed=1, status='succeeded', trained_tokens=3709770, training_file='file-7UYmyv6nZxyYKNS2RHAxfL', validation_file='file-4yQm5QHKGonqKMJFoTz45m', estimated_finish=1732554201, integrations=[], user_provided_suffix=None)

In [185]:
np.save(f"../models/{job_id}.npy", job_id)

In [186]:
ft_job.trained_tokens * 3 * (1/1e6) * 5

55.646550000000005