In [52]:
import pandas as pd
from dotenv import load_dotenv
import os
import requests
import json
import re
from pydantic import BaseModel
from openai import OpenAI
import os

load_dotenv()
fireworks_key = os.getenv('FIREWORKS_API_KEY')

splits = {'train': 'data/train-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
train_df = pd.read_parquet(splits['train'])
test_df = pd.read_parquet(splits['test'])

In [53]:
q1_text = test_df['question'].iloc[0]
q1_text

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

In [57]:
import pandas as pd
from openai import OpenAI
import json
import re
from pydantic import BaseModel
import time
from tqdm import tqdm
import os

# Define the output schema using Pydantic
class QAResult(BaseModel):
    question: str
    answer: str

def process_questions(test_df, num_questions=100, output_file="reasoning_traces.txt"):
    # Initialize the Fireworks client
    client = OpenAI(
        base_url="https://api.fireworks.ai/inference/v1",
        api_key=os.getenv("FIREWORKS_API_KEY"),
    )
    
    # Create new columns for reasoning and answers
    test_df['reasoning_trace'] = ''
    test_df['model_answer'] = ''
    
    # Open file for logging
    with open(output_file, 'w', encoding='utf-8') as f:
        # Process the specified number of questions
        for idx in tqdm(range(min(num_questions, len(test_df)))):
            try:
                question = test_df['question'].iloc[idx]
                
                # Log the question
                f.write(f"\nQuestion {idx + 1}:\n{question}\n")
                f.write("-" * 80 + "\n")
                
                # Construct the messages payload
                messages = [{"role": "user", "content": question}]
                
                # Make the API call to the model
                response = client.chat.completions.create(
                    model="accounts/fireworks/models/deepseek-r1",
                    messages=messages,
                    response_format={"type": "json_object", "schema": QAResult.model_json_schema()},
                    max_tokens=3000,
                )
                
                # Extract the content of the response
                response_content = response.choices[0].message.content
                
                # Extract the reasoning part
                reasoning_match = re.search(r"<think>(.*?)</think>", response_content, re.DOTALL)
                reasoning = reasoning_match.group(1).strip() if reasoning_match else "No reasoning provided."
                
                # Extract the JSON part
                json_match = re.search(r"</think>\s*(\{.*\})", response_content, re.DOTALL)
                json_str = json_match.group(1).strip() if json_match else "{}"
                
                # Parse the JSON string using model_validate_json
                qa_result = QAResult.model_validate_json(json_str)
                
                # Store in DataFrame
                test_df.at[idx, 'reasoning_trace'] = reasoning
                test_df.at[idx, 'model_answer'] = qa_result.answer
                
                # Log to file
                f.write("Reasoning:\n")
                f.write(reasoning + "\n")
                f.write("\nQA Result:\n")
                f.write(qa_result.model_dump_json(indent=4) + "\n")
                f.write("=" * 80 + "\n")
                
                # Add a small delay to avoid rate limiting
                time.sleep(0.5)
                
            except Exception as e:
                error_msg = f"Error processing question {idx + 1}: {str(e)}"
                print(error_msg)
                f.write(f"\nERROR: {error_msg}\n")
                f.write("=" * 80 + "\n")
                continue
    
    # Try to save the DataFrame
    try:
        test_df.to_parquet('processed_test_data.parquet')
        print("Successfully saved processed data to parquet file")
    except Exception as e:
        print(f"Error saving DataFrame: {str(e)}")
        print("Results are still available in the text file")
    
    return test_df

# Run the processing
processed_df = process_questions(test_df)

  3%|▎         | 3/100 [01:20<46:50, 28.97s/it]

Error processing question 3: 2 validation errors for QAResult
question
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
answer
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing


 41%|████      | 41/100 [13:00<26:05, 26.53s/it]

Error processing question 41: 2 validation errors for QAResult
question
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
answer
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing


 86%|████████▌ | 86/100 [27:18<06:01, 25.84s/it]

Error processing question 86: 2 validation errors for QAResult
question
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
answer
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing


 94%|█████████▍| 94/100 [30:09<02:34, 25.74s/it]

Error processing question 94: 2 validation errors for QAResult
question
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
answer
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing


100%|██████████| 100/100 [32:40<00:00, 19.60s/it]

Error processing question 100: 2 validation errors for QAResult
question
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
answer
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/missing
Successfully saved processed data to parquet file





In [None]:
# save dataframe
processed_df.to_parquet('data/processed_test_data.parquet')