# Generate Binary Classifications 

In [47]:
import sys
import toml
from openai import OpenAI
from pathlib import Path
import json
from typing import Dict, Any, List

import pandas as pd

import re

sys.path.append('..')

from api_utils import load_api_params

In [48]:
# Secrets file location handed to load_api_params() below.
SECRETS_PATH = "../.secrets.toml" 

# Model Parameters (change if needed)
# TEMPERATURE and STREAM are forwarded unchanged to
# client.chat.completions.create() by generate_completion().
TEMPERATURE = 0.3
# NOTE(review): check that generate_completion() can consume a streamed
# response before switching this to True.
STREAM = False

# Model identifier sent as `model=` on every completion request.
MODEL = 'alibayram/medgemma'

# Load the question-answer turns to classify. Later cells read the columns
# turn_id, personA_question and personB_answer from each row.
df_input = pd.read_csv('../inputs/binary_class_2_eval.csv')


# Load API parameters and initialize client
# API_CALL_PARAMS must provide the 'API_URL' and 'API_KEY' entries used here.
API_CALL_PARAMS = load_api_params(SECRETS_PATH)
client = OpenAI(
    base_url=API_CALL_PARAMS['API_URL'],
    api_key=API_CALL_PARAMS['API_KEY']
)

In [49]:
def generate_completion(model: str, messages: List[Dict[str, str]]) -> str:
    """Request a chat completion and return the reply text.

    Relies on the module-level ``client``, ``TEMPERATURE`` and ``STREAM``
    settings.

    Args:
        model: Model identifier sent as ``model=``.
        messages: Chat messages, each a dict with ``role`` and ``content``.

    Returns:
        The assistant's reply text, or an empty string when the API returned
        no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=TEMPERATURE,
        stream=STREAM
    )

    if STREAM:
        # A streamed call yields incremental chunks instead of one response
        # object, so the reply is stitched together from the per-chunk deltas.
        pieces = []
        for chunk in response:
            if not chunk.choices:
                continue
            delta_text = chunk.choices[0].delta.content
            if delta_text:
                pieces.append(delta_text)
        return "".join(pieces)

    # `content` may be None; normalise to "" so the declared `str` return
    # type holds for callers.
    return response.choices[0].message.content or ""

def llm_process(prompt, conversation_turn):
    """Send one conversation turn to the model and return its raw reply.

    Args:
        prompt: Instruction text sent as the ``system`` message.
        conversation_turn: Text sent as the ``user`` message.

    Returns:
        Whatever ``generate_completion`` returns for ``MODEL``.

    Raises:
        RuntimeError: If the completion call fails for any reason. The
            original error is attached as ``__cause__``.
    """
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": conversation_turn}
    ]
    try:
        return generate_completion(MODEL, messages)
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause instead of an incidental "during handling" error.
        raise RuntimeError(f"Error generating completion: {e}") from e
    
def extract_json_from_markdown(text):
    """Return the payload of the first fenced code block in ``text``.

    Accepts a fence tagged ``json`` in any letter case as well as a fence
    with no language tag, since models do not label their output
    consistently. Whitespace just inside the fence is dropped.

    Args:
        text: Raw model reply, possibly wrapping its JSON in a Markdown
            code fence and possibly surrounded by extra prose.

    Returns:
        The text between the opening and closing fence, or ``text``
        unchanged when no complete fence is present.
    """
    # Optional "json" tag, then lazily capture up to the first closing fence.
    pattern = r'```(?:json)?\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1)
    return text  # No fenced block: assume the reply is already bare JSON

def create_conversation_turn(row):
    """Render one question/answer pair as a single 'PersonA: ... PersonB: ...' string.

    ``row`` is any mapping-like object exposing the ``personA_question`` and
    ``personB_answer`` keys.
    """
    question = row['personA_question']
    answer = row['personB_answer']
    return "PersonA: {} PersonB: {}".format(question, answer)


In [50]:
def process_row(row, prompt):
    """Classify one input row with the LLM and return a flat result record.

    The record always carries the row's ``turn_id``, question, answer and the
    rendered conversation turn. ``label`` and ``explanation`` come from the
    model's JSON reply; if that reply cannot be parsed, ``label`` is
    ``'ERROR'`` and ``explanation`` describes the failure. Errors raised by
    the LLM call itself are not caught here.
    """
    turn_text = create_conversation_turn(row)
    raw_reply = llm_process(prompt, turn_text)

    # Only the parsing step is guarded; any failure becomes an ERROR record.
    try:
        payload = json.loads(extract_json_from_markdown(raw_reply))
        label = payload['label']
        explanation = payload['explanation']
    except Exception as exc:
        label = 'ERROR'
        explanation = f'Failed to parse LLM response: {str(exc)}'

    return {
        'turn_id': row['turn_id'],
        'original_question': row['personA_question'],
        'original_answer': row['personB_answer'],
        'conversation_turn': turn_text,
        'label': label,
        'explanation': explanation,
    }

### Prompt A

In [51]:
# Read prompt A. The encoding is pinned to UTF-8 so the prompt text is
# decoded the same way regardless of the machine's locale default.
with open('./prompts/binary_class_prompt_A.txt', 'r', encoding='utf-8') as file:
    prompt_a = file.read()

In [52]:
# Classify every conversation turn with prompt A, echoing progress per turn.
results = []
for _row_idx, row in df_input.iterrows():
    result = process_row(row, prompt_a)
    print(f"Processed turn {result['turn_id']}: {result['label']}")
    results.append(result)

# Gather the prompt-A results into a DataFrame and persist them as CSV.
df_output_A = pd.DataFrame(results)
print(f"\nProcessed {len(df_output_A)} conversation turns")
df_output_A.to_csv('binary_class_judge_output_A.csv', index=False)

Processed turn 1: False
Processed turn 2: False
Processed turn 3: True
Processed turn 4: False
Processed turn 5: False
Processed turn 6: False
Processed turn 7: True
Processed turn 8: False
Processed turn 9: True
Processed turn 10: True

Processed 10 conversation turns


### Prompt B

In [53]:
# Read prompt B. The encoding is pinned to UTF-8 so the prompt text is
# decoded the same way regardless of the machine's locale default.
with open('./prompts/binary_class_prompt_B.txt', 'r', encoding='utf-8') as file:
    prompt_b = file.read()

In [54]:
# Classify every conversation turn with prompt B, echoing progress per turn.
results = []
for _row_idx, row in df_input.iterrows():
    result = process_row(row, prompt_b)
    print(f"Processed turn {result['turn_id']}: {result['label']}")
    results.append(result)

# Gather the prompt-B results into a DataFrame and persist them as CSV.
df_output_B = pd.DataFrame(results)
print(f"\nProcessed {len(df_output_B)} conversation turns")
df_output_B.to_csv('binary_class_judge_output_B.csv', index=False)

Processed turn 1: False
Processed turn 2: True
Processed turn 3: True
Processed turn 4: False
Processed turn 5: False
Processed turn 6: False
Processed turn 7: True
Processed turn 8: False
Processed turn 9: True
Processed turn 10: True

Processed 10 conversation turns
