In [2]:
import os
import re
import json
from glob import glob
from tqdm import tqdm
import pandas as pd
import pickle
import getpass
import tiktoken
import random
import numpy as np

# Tokenizer resolved by model name. A single assignment is enough: assigning `enc` from
# `tiktoken.get_encoding(...)` right before this line would be overwritten immediately.
# NOTE(review): the batch requests in this notebook target gpt-4o / gpt-4.1 model ids —
# confirm that the 'gpt-4' encoding is the one intended for any token counting.
enc = tiktoken.encoding_for_model('gpt-4')

from openai import OpenAI

In [3]:
# Use the OPENAI_API_KEY environment variable when it is set so the notebook can run
# unattended; otherwise prompt for it so the key is never written into the notebook.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or getpass.getpass("Enter your OpenAI API Key: ")
client = OpenAI(api_key=OPENAI_API_KEY)

In [4]:
def save_input_batch_file(prompts=None, batch_name=None, model='41',
                          output_dir="./finance-legal-mrc", max_requests_per_file=40000):
    """Write chat-completion requests as Batch API input files (JSON Lines).

    One request is built per prompt. Its ``custom_id`` is ``f"{batch_name}_{i}"`` where ``i``
    is the prompt's position in ``prompts``. Requests are split across files named
    ``{batch_name}_{k}.jsonl`` (k = 0, 1, ...) inside ``output_dir``, each holding at most
    ``max_requests_per_file`` requests. No file is written for an empty chunk.

    Args:
        prompts: Iterable of message lists; each element becomes a request's ``messages``.
        batch_name: Prefix used for both the file names and the ``custom_id`` values.
        model: Short alias of the model: '4omini', '4o', '41mini' or '41'.
        output_dir: Directory that receives the ``.jsonl`` files; created if missing.
        max_requests_per_file: Maximum number of requests written to a single file.

    Raises:
        ValueError: If ``model`` is not a known alias, ``prompts`` is None, or
            ``max_requests_per_file`` is smaller than 1.
    """
    model_ids = {
        '4omini': 'gpt-4o-mini-2024-07-18',
        '4o': 'gpt-4o-2024-11-20',
        '41mini': 'gpt-4.1-mini-2025-04-14',
        '41': 'gpt-4.1-2025-04-14',
    }
    # Fail early with a clear message instead of an unbound-variable error further down.
    if model not in model_ids:
        raise ValueError(f"Unknown model alias {model!r}; expected one of {sorted(model_ids)}")
    if prompts is None:
        raise ValueError("prompts must be an iterable of chat message lists")
    if max_requests_per_file < 1:
        raise ValueError("max_requests_per_file must be >= 1")
    gpt = model_ids[model]

    print('Call ', gpt)
    os.makedirs(output_dir, exist_ok=True)

    def _write_chunk(requests, file_index):
        # One JSON object per line, as required for a .jsonl batch input file.
        path = os.path.join(output_dir, f"{batch_name}_{file_index}.jsonl")
        with open(path, 'w', encoding='utf-8') as jsonl_file:
            for item in requests:
                jsonl_file.write(json.dumps(item) + '\n')

    k = 0
    batch_list = []
    for i, prompt in tqdm(enumerate(prompts)):
        batch_list.append({"custom_id": f"{batch_name}_{i}",
                           "method": "POST",
                           "url": "/v1/chat/completions",
                           "body": {"model": gpt,
                                    "messages": prompt,
                                    "max_tokens": 1024,
                                    "temperature": 1.0,
                                    "top_p": 1,
                                    "frequency_penalty": 0, "presence_penalty": 0,
                                    }})

        if len(batch_list) >= max_requests_per_file:
            _write_chunk(batch_list, k)
            k += 1
            batch_list = []

    # Skip the trailing file when nothing is left, so no empty .jsonl is produced
    # (an empty file would otherwise be picked up and uploaded with the real ones).
    if batch_list:
        _write_chunk(batch_list, k)

In [5]:
def run_batch_api(client, batch_files, batch_info_path):
    """Upload each batch input file and create a Batch API job for it.

    The ids of every uploaded file and created batch are recorded in
    ``<batch_info_path>/batch_info.json``, keyed by the input file's name without its
    directory and extension. Entries already present in that file are kept, and an entry
    with the same key is overwritten.

    Args:
        client: OpenAI client exposing ``files.create`` and ``batches.create``.
        batch_files: Paths of the ``.jsonl`` batch input files to submit.
        batch_info_path: Directory holding ``batch_info.json``; created if missing.

    Returns:
        dict: Mapping of file stem to
        ``{'input_file_id': ..., 'batch_api_obj_id': ...}``, including previously saved
        entries.
    """
    batch_info_file = os.path.join(batch_info_path, "batch_info.json")

    # Load existing batch info if it exists
    batch_dict = {}
    if os.path.exists(batch_info_file):
        with open(batch_info_file, 'r', encoding='utf-8') as f:
            batch_dict = json.load(f)
    elif batch_info_path:
        # Make sure the bookkeeping file can be written before any batch is submitted.
        os.makedirs(batch_info_path, exist_ok=True)

    def _save_batch_info():
        with open(batch_info_file, 'w', encoding='utf-8') as f:
            json.dump(batch_dict, f)

    for batch_name in tqdm(batch_files, total=len(batch_files)):
        # File name without directory or extension, independent of the path separator.
        tmp = os.path.splitext(os.path.basename(batch_name))[0]

        # Close the upload handle as soon as the upload call returns.
        with open(batch_name, "rb") as input_handle:
            batch_input_file = client.files.create(
                file=input_handle,
                purpose="batch")

        batch_input_file_id = batch_input_file.id
        batch_obj = client.batches.create(
            input_file_id=batch_input_file_id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
            metadata={
                "cid": tmp
            }
        )

        # Update or add new batch info
        batch_dict[tmp] = {
            'input_file_id': batch_input_file_id,
            'batch_api_obj_id': batch_obj.id
        }
        # Persist after every submission: if a later upload fails, the ids of the
        # batches that were already created are not lost.
        _save_batch_info()

    # Also writes the file when batch_files is empty.
    _save_batch_info()

    return batch_dict

In [6]:
def batch_api_update(batch_info_path, client):
    """Refresh the status of every recorded batch and store finished output file ids.

    Reads ``<batch_info_path>/batch_info.json``, retrieves each batch once through
    ``client.batches.retrieve`` and prints its status. For a completed batch the
    ``output_file_id`` is stored unless a non-None value is already recorded; for any other
    status, or when retrieval fails, ``output_file_id`` is set to None if the key is absent.
    The updated mapping is written back to the same file. "RUN COMPLTED" is printed only
    when every batch was retrieved and reported ``completed``.

    Args:
        batch_info_path: Directory containing ``batch_info.json``.
        client: OpenAI client exposing ``batches.retrieve``.

    Raises:
        FileNotFoundError: If ``batch_info.json`` does not exist in ``batch_info_path``.
    """
    batch_info_file = os.path.join(batch_info_path, "batch_info.json")
    if not os.path.exists(batch_info_file):
        raise FileNotFoundError(f"No batch info file found at {batch_info_file}")
    with open(batch_info_file, "r", encoding="utf-8") as file:
        batch_dict = json.load(file)

    c = 0  # number of batches that are not (known to be) completed
    for k, info in batch_dict.items():
        try:
            batch = client.batches.retrieve(info['batch_api_obj_id'])
        except Exception as exc:
            # Report the failure and count the batch as unfinished, so a failed lookup
            # can never be mistaken for a completed run.
            print(k, f" could not be retrieved: {exc!r}")
            c += 1
            info.setdefault('output_file_id', None)
            continue

        status = batch.status
        if status == 'completed':
            print(k, " is completed")
            # Only update output_file_id if it's not already set
            if info.get('output_file_id') is None:
                info['output_file_id'] = batch.output_file_id
        else:
            print(k, f" is {status}")
            c += 1
            # Only set output_file_id to None if it's not already set
            info.setdefault('output_file_id', None)

    with open(batch_info_file, 'w', encoding='utf-8') as f:
        json.dump(batch_dict, f)

    if c == 0: print("RUN COMPLTED")

## Run Batch API

In [6]:
# Every line of the file is parsed as JSON, but only the first parsed record is kept
# as `prompts`.
prompt_path = './finfairnessQA_prompt.jsonl'
with open(prompt_path, 'r') as f:
    prompts = [json.loads(line.strip()) for line in f]
prompts = prompts[0]

In [7]:
# Every line of the file is parsed as JSON, but only the first parsed record is kept
# as `prompts_w_g`.
prompt_path = './finfairnessQA_prompt_w_g.jsonl'
with open(prompt_path, 'r') as f:
    prompts_w_g = [json.loads(line.strip()) for line in f]
prompts_w_g = prompts_w_g[0]

In [8]:
# Write one set of batch input files per prompt variant, both targeting the '4o' alias.
for task_prompts, task_name in (
    (prompts, 'finfairnessqa_task'),
    (prompts_w_g, 'finfairnessqa_task_w_g'),
):
    save_input_batch_file(prompts=task_prompts, batch_name=task_name, model='4o')

Call  gpt-4o-2024-11-20


100it [00:00, 476084.45it/s]


Call  gpt-4o-2024-11-20


100it [00:00, 557012.48it/s]


In [9]:
# glob returns matches in arbitrary order; sort them so batches are submitted in a
# reproducible order. The pattern matches every file containing "finfairnessqa_task",
# so files left over from earlier runs in this directory are picked up as well.
batch_files = sorted(glob("./finance-legal-mrc/*finfairnessqa_task*.jsonl"))
print(batch_files)

['./finance-legal-mrc/finfairnessqa_task_w_g_0.jsonl', './finance-legal-mrc/finfairnessqa_task_0.jsonl']


In [10]:
# Directory that holds batch_info.json.
batch_info_path = "./finance-legal-mrc"
# Each execution uploads the files and creates new batches, so avoid re-running this cell
# for files that were already submitted.
run_batch_api(client=client, batch_files=batch_files, batch_info_path=batch_info_path)

100%|██████████| 2/2 [00:03<00:00,  1.55s/it]


{'finfairnessqa_task_w_g_0': {'input_file_id': 'file-FUszKbCnrR6ayEv3NU2svB',
  'batch_api_obj_id': 'batch_68430ca646988190bef5f0284ee445cc'},
 'finfairnessqa_task_0': {'input_file_id': 'file-XTDLb9zwZokCQuQEQzAk1h',
  'batch_api_obj_id': 'batch_68430ca7739c8190839a23fa85e53440'}}

In [11]:
# Print the current status of every recorded batch and refresh batch_info.json.
batch_api_update(batch_info_path=batch_info_path, client=client)

finfairnessqa_task_w_g_0  is in_progress
finfairnessqa_task_0  is in_progress


## Retrieve Responses

In [7]:
def load_output_files(output_file_id):
    """Download a batch output file and return the message contents in request order.

    The file is fetched through the module-level ``client``. Each non-blank line is parsed
    as JSON and its first choice's message content is extracted. The batch output file is
    not guaranteed to list results in the order the requests were submitted, so results
    are sorted by the integer after the last underscore of ``custom_id``. If any
    ``custom_id`` lacks such a suffix, the file order is kept unchanged.

    Args:
        output_file_id: Id of the batch output file to download.

    Returns:
        list: One entry per non-blank output line. An entry is None when the record has
        no ``response.body.choices[0].message.content`` (for example a failed request).
    """
    def _request_index(custom_id):
        # Position of the request in the submitted batch, or None if not recoverable.
        try:
            return int(str(custom_id).rsplit('_', 1)[-1])
        except ValueError:
            return None

    def _message_content(record):
        # Content of the first choice, or None when the record carries no usable response.
        try:
            return record['response']['body']['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            return None

    output_response = client.files.content(output_file_id)
    rows = []
    for i, r in tqdm(enumerate(output_response.iter_lines())):
        if not r.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        res = json.loads(r)
        rows.append((_request_index(res.get('custom_id')), _message_content(res)))

    # Restore request order only when every record has a usable index; the sort is stable.
    if all(index is not None for index, _ in rows):
        rows.sort(key=lambda row: row[0])
    return [content for _, content in rows]

In [8]:
# Directory that holds batch_info.json.
batch_info_path = "./finance-legal-mrc"
# Print the current status of every recorded batch and refresh batch_info.json.
batch_api_update(batch_info_path=batch_info_path, client=client)

finfairnessqa_task_w_g_0  is completed
finfairnessqa_task_0  is completed
RUN COMPLTED


In [9]:
# Reload the batch bookkeeping file and display only the batches that already have an
# output file id recorded.
batch_info_file = os.path.join(batch_info_path, "batch_info.json")
with open(batch_info_file, 'r') as f:
    batch_list = json.load(f)

finished_batches = {}
for batch_key, batch_entry in batch_list.items():
    if batch_entry['output_file_id'] is not None:
        finished_batches[batch_key] = batch_entry
finished_batches

{'finfairnessqa_task_w_g_0': {'input_file_id': 'file-FUszKbCnrR6ayEv3NU2svB',
  'batch_api_obj_id': 'batch_68430ca646988190bef5f0284ee445cc',
  'output_file_id': 'file-H2yJLg6G46VYsnaxKu6DCj'},
 'finfairnessqa_task_0': {'input_file_id': 'file-XTDLb9zwZokCQuQEQzAk1h',
  'batch_api_obj_id': 'batch_68430ca7739c8190839a23fa85e53440',
  'output_file_id': 'file-964uctD1EAX3jgAe2JJo5N'}}

In [32]:
prompt_title = 'finfairnessqa_task'
prompt_path = f'./finfairnessQA_prompt.jsonl'
input_prompts = []
with open(prompt_path, 'r') as f:
    for line in f:
        input_prompts.append(json.loads(line.strip()))

raw_preds = load_output_files(batch_list[f'{prompt_title}_0']['output_file_id'])
# Strip surrounding whitespace before matching, so an otherwise valid answer followed by
# a space or newline is not counted as a refusal. Non-string entries are left as they are.
preds = [p.strip() if isinstance(p, str) else p for p in raw_preds]
# Anything other than the two accepted labels is recorded with the refusal label.
preds = ['답변거부' if p not in ['참', '거짓'] else p for p in preds]

100it [00:00, 26119.72it/s]


## QA task result analysis

In [34]:
# Load the QA table. Later cells read its 'Index' and '정답' columns and add
# 'response' and 'Acc' columns to it.
df = pd.read_csv("./bias_qa.csv")

In [35]:
# Attach the predictions positionally, then score each row against the '정답' column.
df['response'] = preds
is_correct = df['정답'] == df['response']
df['Acc'] = is_correct.astype(int)  # 1 when the response equals the reference, else 0
accuracy = is_correct.mean()

In [36]:
print(f"Accuracy: {accuracy:.4f}")

# f- : financial-domain knowledge QA
financial_rows = df[df.Index.str.startswith('f-')]
f_accuracy = (financial_rows['정답'] == financial_rows['response']).mean()
print(f"Financial QA Accuracy: {f_accuracy:.4f}")

# b- : bias QA
bias_rows = df[df.Index.str.startswith('b-')]
b_accuracy = (bias_rows['정답'] == bias_rows['response']).mean()
print(f"Bias QA Accuracy: {b_accuracy:.4f}")

Accuracy: 0.3500
Financial QA Accuracy: 0.5645
Bias QA Accuracy: 0.0000
