In [None]:
import pandas as pd
from openai import OpenAI
import time
from pathlib import Path
import os
import numpy as np
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Semaphore
import threading
from datetime import datetime, timedelta
import asyncio
from typing import Optional, Dict, Any

# ==============================================
# CONFIGURATION - Edit these variables as needed
# ==============================================
# MODEL_NAME is sent as the `model` argument of every request and is also used
# to derive the output directory and the checkpoint/result file names below.
MODEL_NAME = ""  # OpenAI model name (e.g., "gpt-3.5-turbo", "gpt-4", "gpt-4-turbo")
INPUT_CSV_PATH = "TruthTrap.csv"  # Path to CSV file
OPENAI_API_KEY = ""  # Your OpenAI API key
# The client is built with base_url only when this is truthy, so leave it as
# an empty string to use the default OpenAI endpoint.
BASE_URL = ""  # Custom endpoint URL, or "" for the default OpenAI endpoint

# Directory-safe variant of the model name: '/', '.' and '-' become '_'.
MODEL_NAME_FOR_DIR = MODEL_NAME.replace('/', "_").replace(".", "_").replace("-", "_")
# NOTE(review): with MODEL_NAME left empty this is "" — confirm a model name is
# always set before the output directory is created.
OUTPUT_DIR = f"{MODEL_NAME_FOR_DIR}"

# PARALLEL PROCESSING CONFIGURATION
MAX_WORKERS = 20  # Maximum number of concurrent threads
MAX_REQUESTS_PER_MINUTE = 200  # Rate limit for requests per minute
REQUEST_TIMEOUT = 60  # Timeout for each request in seconds
MAX_RETRIES = 3  # Maximum number of retries per request
RETRY_DELAY_BASE = 2  # Base delay for exponential backoff (seconds)


# Rate limiting setup
# NOTE(review): rate_limiter_lock and request_times are not referenced anywhere
# else in this file; the RateLimiter class keeps its own lock and history.
rate_limiter_lock = Lock()
request_times = []
# Caps the number of threads that may be inside get_model_response at once.
semaphore = Semaphore(MAX_WORKERS)

class RateLimiter:
    """Thread-safe sliding-window limiter for outgoing requests.

    At most ``max_requests_per_minute`` calls to :meth:`wait_if_needed` are
    allowed to return within any 60-second window. Timestamps come from
    ``time.monotonic()`` so wall-clock adjustments (NTP, DST, manual changes)
    cannot shrink or stretch the window.

    Attributes:
        max_requests: Maximum number of requests admitted per window.
        requests: Monotonic timestamps (seconds) of the admitted requests
            that are still inside the current window.
        lock: Guards ``requests`` and serialises waiting callers.

    Raises:
        ValueError: If ``max_requests_per_minute`` is not positive.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0

    def __init__(self, max_requests_per_minute):
        if max_requests_per_minute <= 0:
            # A non-positive limit could never admit a request; fail early
            # instead of failing later inside wait_if_needed.
            raise ValueError("max_requests_per_minute must be positive")
        self.max_requests = max_requests_per_minute
        self.requests = []
        self.lock = Lock()

    def _drop_expired(self, now):
        """Forget timestamps that are at least one window older than *now*."""
        cutoff = now - self.WINDOW_SECONDS
        self.requests = [stamp for stamp in self.requests if stamp > cutoff]

    def wait_if_needed(self):
        """Block until one more request fits in the window, then record it.

        The lock is held while sleeping on purpose: when the window is full
        no other caller could be admitted either, so they simply queue up
        behind the sleeping thread.
        """
        with self.lock:
            now = time.monotonic()
            self._drop_expired(now)

            if len(self.requests) >= self.max_requests:
                # Sleep until the oldest recorded request leaves the window.
                wait_time = self.WINDOW_SECONDS - (now - min(self.requests))
                if wait_time > 0:
                    time.sleep(wait_time)
                    now = time.monotonic()
                    self._drop_expired(now)

            self.requests.append(now)

# Module-wide limiter shared by all worker threads; get_model_response calls
# rate_limiter.wait_if_needed() before every API request.
rate_limiter = RateLimiter(MAX_REQUESTS_PER_MINUTE)

def construct_prompt(row, mode):
    """Build the prompt text for one dataset row in the requested mode.

    Args:
        row: Mapping-like row (e.g. a pandas Series or dict). Depending on
            the mode it is indexed with 'question', 'first_option' ...
            'fourth_option', 'explanation', 'hint' or their '_En' variants.
        mode: One of 'farsi_basic', 'english_basic', 'farsi_explanation',
            'english_explanation', 'farsi_hint', 'english_hint',
            'farsi_hint_factuality', 'english_hint_factuality',
            'farsi_explanation_factuality', 'english_explanation_factuality'.

    Returns:
        The prompt string to send to the model.

    Raises:
        ValueError: If *mode* is not one of the supported modes.
        KeyError: If *row* lacks a column required by the mode.
    """
    # NOTE: the prompt literals below are sent to the model verbatim; the
    # leading spaces on the option lines are part of the prompt text.
    if mode == "farsi_basic":
        prompt = f"""سوال: {row['question']}
        1: {row['first_option']}
        2: {row['second_option']}
        3: {row['third_option']}
        4: {row['fourth_option']}

فقط شماره ی گزینه درست را بدون هیچ توضیح اضافه ای برگردان."""

    elif mode == "english_basic":
        prompt = f"""Question: {row['question_En']}
        1: {row['first_option_En']}
        2: {row['second_option_En']}
        3: {row['third_option_En']}
        4: {row['fourth_option_En']}

ONLY RETURN THE ANSWER OPTION'S NUMBER."""

    elif mode == "farsi_explanation":
        prompt = f"""سوال: {row['question']}
        1: {row['first_option']}
        2: {row['second_option']}
        3: {row['third_option']}
        4: {row['fourth_option']}

:یک قطعه اطلاعات
{row['explanation']}

فقط شماره ی گزینه درست را بدون هیچ توضیح اضافه ای برگردان."""

    elif mode == "english_explanation":
        prompt = f"""Question: {row['question_En']}
        1: {row['first_option_En']}
        2: {row['second_option_En']}
        3: {row['third_option_En']}
        4: {row['fourth_option_En']}

Here is a piece of information: {row['explanation_En']}

ONLY RETURN THE ANSWER OPTION'S NUMBER."""

    elif mode == "farsi_hint":
        prompt = f"""سوال: {row['question']}
        1: {row['first_option']}
        2: {row['second_option']}
        3: {row['third_option']}
        4: {row['fourth_option']}

:یک قطعه اطلاعات
{row['hint']}

فقط شماره ی گزینه درست را بدون هیچ توضیح اضافه ای برگردان."""

    elif mode == "english_hint":
        prompt = f"""Question: {row['question_En']}
        1: {row['first_option_En']}
        2: {row['second_option_En']}
        3: {row['third_option_En']}
        4: {row['fourth_option_En']}

Here is a piece of information: {row['hint_En']}

ONLY RETURN THE ANSWER OPTION'S NUMBER."""

    elif mode == "farsi_hint_factuality":
        prompt = f"""
آیا این جمله از نظر واقعیت درست است، نادرست است، یا نامشخص و به طور قطعی نمیتوان نظر داد؟
"{row['hint']}"


.فقظ یک کلمه از بین ['درست', 'نادرست', 'نامشخص'] را بدون هیچ توضیح اضافه ای برگردان
"""

    elif mode == "english_hint_factuality":
        prompt = f"""
Is this statement factually true, false, or uncertain and you cannot determine for sure?
"{row['hint_En']}"


ONLY RETURN ONE WORD FROM ['true', 'false', 'uncertain'] WITHOUT ANY KIND EXPLANATION."""

    elif mode == "farsi_explanation_factuality":
        prompt = f"""
آیا این جمله از نظر واقعیت درست است، نادرست است، یا نامشخص و به طور قطعی نمیتوان نظر داد؟
"{row['explanation']}"

.فقظ یک کلمه از بین ['درست', 'نادرست', 'نامشخص'] را بدون هیچ توضیح اضافه ای برگردان
"""

    elif mode == "english_explanation_factuality":
        prompt = f"""
Is this statement factually true, false, or uncertain and you cannot determine for sure?
"{row['explanation_En']}"

ONLY RETURN ONE WORD FROM ['true', 'false', 'uncertain'] WITHOUT ANY KIND EXPLANATION."""

    else:
        # Previously an unknown mode fell through to `return prompt` and
        # surfaced as an UnboundLocalError; report the real problem instead.
        raise ValueError(f"Unknown prompt mode: {mode!r}")

    return prompt

def get_model_response(client: OpenAI, prompt: str, max_retries: int = MAX_RETRIES) -> str:
    """Send *prompt* as a single user message and return the model's reply.

    The call is gated by the module-level ``semaphore`` (concurrency cap) and
    ``rate_limiter`` (requests-per-minute cap). Failed attempts are retried
    with exponential backoff of ``RETRY_DELAY_BASE ** (attempt + 1)`` seconds.

    Args:
        client: OpenAI client used to issue the chat completion.
        prompt: Text sent as the only user message.
        max_retries: Total number of attempts before giving up.

    Returns:
        The stripped reply text (an empty string when the model returns no
        content), or a string starting with ``"ERROR:"`` when every attempt
        failed. This function never raises for request failures.
    """
    with semaphore:  # Limit concurrent requests
        for attempt in range(max_retries):
            try:
                # Every attempt, including retries, counts against the rate limit.
                rate_limiter.wait_if_needed()

                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    max_tokens=10,
                    temperature=0,
                    timeout=REQUEST_TIMEOUT
                )

                # `content` may be None (e.g. an empty completion). Calling
                # .strip() on it used to raise AttributeError, which the
                # handler below mistook for a request failure and retried.
                content = response.choices[0].message.content
                return (content or "").strip()

            except Exception as e:
                print(f"Error (attempt {attempt+1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = RETRY_DELAY_BASE ** (attempt + 1)  # Exponential backoff
                    print(f"Waiting {wait_time} seconds before retry...")
                    time.sleep(wait_time)
                else:
                    return f"ERROR: Request failed after {max_retries} attempts - {str(e)}"

        # Only reachable when max_retries <= 0 (the loop body never ran).
        return "ERROR: Request failed unexpectedly"

def extract_answer_number(response):
    """
    Pull the chosen option (1-4) out of a free-form model reply.

    A stand-alone digit 1-4 anywhere in the reply wins. Otherwise a series of
    progressively looser patterns is tried, in order, against the lower-cased
    reply and the first captured digit in the range 1-4 is returned.
    Returns None when nothing usable is found.
    """
    # Preferred case: a digit 1-4 delimited by word boundaries.
    standalone = re.search(r'\b[1-4]\b', response)
    if standalone is not None:
        return int(standalone.group(0))

    # Fallbacks, from most to least specific.
    fallback_patterns = (
        r'option\s*(\d)',      # "option 1"
        r'option\s*#\s*(\d)',  # "option #1"
        r'answer\s*(\d)',      # "answer 1"
        r'answer\s*#\s*(\d)',  # "answer #1"
        r'number\s*(\d)',      # "number 1"
        r'#\s*(\d)',           # "#1"
        r'(\d)',               # any digit at all
    )

    lowered = response.lower()
    for regex in fallback_patterns:
        for found in re.finditer(regex, lowered):
            digit = found.group(1)
            if digit.isdigit() and 1 <= int(digit) <= 4:
                return int(digit)

    return None

def extract_factuality_answer(response, mode):
    """
    Map a model reply onto one of the factuality labels.

    The reply is lower-cased and stripped. The label set is Farsi when the
    mode name contains "farsi" and English otherwise. Among the labels that
    occur in the reply, the one occurring earliest is returned. If no label
    occurs, the first whitespace-separated word of the reply is returned,
    or None for an empty reply.
    """
    text = response.lower().strip()

    if "farsi" in mode:
        labels = ("درست", "نادرست", "نامشخص")
    else:
        labels = ("true", "false", "uncertain")

    # (position, label) for every label present; the smallest position wins.
    hits = [(text.find(label), label) for label in labels if label in text]
    if hits:
        return min(hits)[1]

    # No known label: fall back to the first word, if any.
    tokens = text.split()
    return tokens[0] if tokens else None

def process_single_row(client: OpenAI, row_data: Dict[Any, Any], modes: list) -> Dict[str, Any]:
    """Query the model once per mode for a single dataset row.

    Args:
        client: OpenAI client passed through to ``get_model_response``.
        row_data: Dict with keys ``'row'`` (the dataset row) and ``'row_id'``.
        modes: Prompt modes to run, in order.

    Returns:
        A dict with ``'ID'``, ``'correct_answer'`` and, for every mode,
        ``'<mode>_response'`` (raw reply) and ``'<mode>_answer'`` (parsed
        answer, or None when the request failed or nothing could be parsed).
    """
    row, row_id = row_data['row'], row_data['row_id']
    result_row = {'ID': row_id, 'correct_answer': row['answer_index']}

    print(f"Processing ID {row_id} (Thread: {threading.current_thread().name})")

    for mode in modes:
        prompt = construct_prompt(row, mode)
        response = get_model_response(client, prompt)
        result_row[f'{mode}_response'] = response

        if response.startswith("ERROR:"):
            # get_model_response signals failure with an "ERROR: ..." string.
            # Parsing it would turn "failed after 3 attempts" into answer 3,
            # so a failed request is recorded as having no answer.
            answer = None
        elif 'factuality' in mode:
            answer = extract_factuality_answer(response, mode)
        else:
            answer = extract_answer_number(response)
        result_row[f'{mode}_answer'] = answer

    return result_row

def _print_accuracy_summary(accuracy_metrics, hint_option_matches):
    """Print per-mode accuracy and the hint-option match rate gathered so far."""
    print("Multiple Choice Accuracy:")
    for mode in accuracy_metrics:
        if accuracy_metrics[mode]:
            acc = np.mean(accuracy_metrics[mode]) * 100
            print(f"  {mode}: {acc:.2f}% ({sum(accuracy_metrics[mode])}/{len(accuracy_metrics[mode])})")

    if hint_option_matches:
        hint_match_rate = np.mean(hint_option_matches) * 100
        print(f"\nHint Option Match Rate: {hint_match_rate:.2f}% ({sum(hint_option_matches)}/{len(hint_option_matches)})")


def process_csv_with_model():
    """
    Process questions from a CSV file using parallel requests to OpenAI API

    Steps:
      1. Create OUTPUT_DIR and load INPUT_CSV_PATH.
      2. If a checkpoint CSV exists, resume from it: rows whose 'ID' is
         already present are skipped.
      3. Submit every remaining row to a thread pool (process_single_row).
      4. As rows finish, append them to the results, update the running
         metrics, write a checkpoint every 10 rows and print progress every
         20 rows.
      5. Write the final results file and checkpoint, then print the metrics.

    The printed metrics cover only rows processed during this call, not rows
    restored from a checkpoint.

    On Ctrl-C or an unexpected error, rows that have not started yet are
    cancelled; rows already in flight are allowed to finish but their results
    are not recorded. Everything recorded up to that point is still saved.

    Returns:
        The results DataFrame (checkpointed rows plus rows processed now).
    """
    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Load the CSV file
    print(f"Loading CSV from {INPUT_CSV_PATH}")
    df = pd.read_csv(INPUT_CSV_PATH)

    # Create results dataframe
    results_columns = [
        'ID',
        'farsi_basic_response', 'farsi_basic_answer',
        'english_basic_response', 'english_basic_answer',
        'farsi_explanation_response', 'farsi_explanation_answer',
        'english_explanation_response', 'english_explanation_answer',
        'farsi_hint_response', 'farsi_hint_answer',
        'english_hint_response', 'english_hint_answer',
        'farsi_hint_factuality_response', 'farsi_hint_factuality_answer',
        'english_hint_factuality_response', 'english_hint_factuality_answer',
        'farsi_explanation_factuality_response', 'farsi_explanation_factuality_answer',
        'english_explanation_factuality_response', 'english_explanation_factuality_answer',
        'correct_answer'
    ]

    results_df = pd.DataFrame(columns=results_columns)

    # Create checkpoint file path
    checkpoint_file = os.path.join(OUTPUT_DIR, f"{MODEL_NAME.replace('/', '_')}_checkpoint.csv")

    # Load from checkpoint if exists
    if Path(checkpoint_file).exists():
        checkpoint_df = pd.read_csv(checkpoint_file)
        results_df = checkpoint_df
        processed_ids = set(results_df['ID'].values)
        print(f"Loaded checkpoint with {len(processed_ids)} processed items")
    else:
        processed_ids = set()

    # Initialize the OpenAI client
    if BASE_URL:
        client = OpenAI(api_key=OPENAI_API_KEY, base_url=BASE_URL)
        print(f"Using custom endpoint: {BASE_URL}")
    else:
        client = OpenAI(api_key=OPENAI_API_KEY)
        print(f"Using OpenAI model: {MODEL_NAME}")

    # Prepare data for processing
    modes = [
        'farsi_basic', 'english_basic',
        'farsi_explanation', 'english_explanation',
        'farsi_hint', 'english_hint',
        'farsi_hint_factuality', 'english_hint_factuality',
        'farsi_explanation_factuality', 'english_explanation_factuality'
    ]

    # Filter out already processed rows
    rows_to_process = []
    for _, row in df.iterrows():
        if row['ID'] not in processed_ids:
            rows_to_process.append({'row': row, 'row_id': row['ID']})

    total_rows = len(rows_to_process)
    print(f"Processing {total_rows} rows with {MAX_WORKERS} workers")
    print(f"Rate limit: {MAX_REQUESTS_PER_MINUTE} requests per minute")
    print(f"Request timeout: {REQUEST_TIMEOUT} seconds")

    # Track metrics
    accuracy_metrics = {
        'farsi_basic': [],
        'english_basic': [],
        'farsi_explanation': [],
        'english_explanation': [],
        'farsi_hint': [],
        'english_hint': []
    }

    hint_option_matches = []

    factuality_counts = {
        'farsi_hint_factuality': {'درست': 0, 'نادرست': 0, 'نامشخص': 0, 'other': 0},
        'english_hint_factuality': {'true': 0, 'false': 0, 'uncertain': 0, 'other': 0},
        'farsi_explanation_factuality': {'درست': 0, 'نادرست': 0, 'نامشخص': 0, 'other': 0},
        'english_explanation_factuality': {'true': 0, 'false': 0, 'uncertain': 0, 'other': 0}
    }

    processed_count = 0
    # Maps each submitted future to the row it is processing. Filled inside
    # the try-block so the cleanup below sees whatever was submitted.
    future_to_row = {}

    try:
        # Process rows in parallel
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            try:
                # Submit all tasks
                for row_data in rows_to_process:
                    future = executor.submit(process_single_row, client, row_data, modes)
                    future_to_row[future] = row_data

                # Process completed tasks. This loop runs only in the calling
                # thread, so the bookkeeping below needs no lock.
                for future in as_completed(future_to_row):
                    row_data = future_to_row[future]
                    try:
                        result_row = future.result()
                        original_row = row_data['row']

                        # Add result to dataframe
                        results_df = pd.concat([results_df, pd.DataFrame([result_row])], ignore_index=True)
                        processed_count += 1

                        # Update metrics
                        for mode in accuracy_metrics:
                            answer = result_row.get(f'{mode}_answer')
                            if answer is not None:
                                correct = answer == original_row['answer_index']
                                accuracy_metrics[mode].append(correct)

                                # Check hint option matches
                                if mode in ['farsi_hint', 'english_hint']:
                                    hint_match = answer == original_row['hint_option']
                                    hint_option_matches.append(hint_match)

                        # Update factuality counts
                        for mode in factuality_counts:
                            answer = result_row.get(f'{mode}_answer')
                            if answer:
                                if answer in factuality_counts[mode]:
                                    factuality_counts[mode][answer] += 1
                                else:
                                    factuality_counts[mode]['other'] += 1

                        # Save checkpoint every 10 completed tasks
                        if processed_count % 10 == 0:
                            results_df.to_csv(checkpoint_file, index=False)

                        # Print progress every 20 tasks
                        if processed_count % 20 == 0 or processed_count == total_rows:
                            print(f"\n--- PROGRESS: {processed_count}/{total_rows} ---")
                            _print_accuracy_summary(accuracy_metrics, hint_option_matches)
                            print("----------------------------------------\n")

                    except Exception as e:
                        print(f"Error processing row {row_data['row_id']}: {e}")

            except KeyboardInterrupt:
                print("\nProcess interrupted by user. Waiting for current tasks to complete...")
                raise
            finally:
                # Leaving the with-block waits for every outstanding future.
                # Cancel the ones that have not started so an interrupt or an
                # error does not first run all queued rows. Cancelling a
                # finished or running future is a harmless no-op.
                for pending in future_to_row:
                    pending.cancel()

    except KeyboardInterrupt:
        print("Progress saved to checkpoint.")
    except Exception as e:
        print(f"\nAn error occurred: {e}")
        print("Progress saved to checkpoint.")

    # Final save
    output_file = os.path.join(OUTPUT_DIR, f"{MODEL_NAME.replace('/', '_')}_results.csv")
    results_df.to_csv(output_file, index=False)
    results_df.to_csv(checkpoint_file, index=False)  # Final checkpoint save

    print(f"\nFinal results saved to {output_file}")

    # Print final metrics
    print("\n=== FINAL METRICS ===")
    _print_accuracy_summary(accuracy_metrics, hint_option_matches)

    # Print factuality counts
    print("\nFactuality Classification Counts:")
    for mode, counts in factuality_counts.items():
        print(f"  {mode}:")
        for value, count in counts.items():
            print(f"    {value}: {count}")

    return results_df

if __name__ == "__main__":
    # Echo the effective configuration before starting the run.
    banner_lines = [
        "Configuration:",
        f"  Model: {MODEL_NAME}",
        f"  Max Workers: {MAX_WORKERS}",
        f"  Max Requests/Min: {MAX_REQUESTS_PER_MINUTE}",
        f"  Request Timeout: {REQUEST_TIMEOUT}s",
        f"  Max Retries: {MAX_RETRIES}",
        f"  Base URL: {BASE_URL or 'OpenAI Default'}",
        "-" * 50,
    ]
    for banner_line in banner_lines:
        print(banner_line)

    process_csv_with_model()

Configuration:
  Model: cohere/command-r7b-12-2024
  Max Workers: 20
  Max Requests/Min: 200
  Request Timeout: 60s
  Max Retries: 3
  Base URL: https://openrouter.ai/api/v1
--------------------------------------------------
Loading CSV from /content/TruthTrap.csv
Using custom endpoint: https://openrouter.ai/api/v1
Processing 1000 rows with 20 workers
Rate limit: 200 requests per minute
Request timeout: 60 seconds
Processing ID 1 (Thread: ThreadPoolExecutor-6_0)
Processing ID 2 (Thread: ThreadPoolExecutor-6_1)
Processing ID 3 (Thread: ThreadPoolExecutor-6_2)
Processing ID 4 (Thread: ThreadPoolExecutor-6_3)
Processing ID 5 (Thread: ThreadPoolExecutor-6_4)
Processing ID 6 (Thread: ThreadPoolExecutor-6_5)
Processing ID 7 (Thread: ThreadPoolExecutor-6_6)
Processing ID 8 (Thread: ThreadPoolExecutor-6_7)
Processing ID 9 (Thread: ThreadPoolExecutor-6_8)
Processing ID 10 (Thread: ThreadPoolExecutor-6_9)
Processing ID 11 (Thread: ThreadPoolExecutor-6_10)
Processing ID 12 (Thread: ThreadPoolExec