# In [None]:
import json
import os
from http import HTTPStatus
import dashscope
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from retrying import retry

# System prompt sent as the 'system' message of every batch request.
# NOTE(review): placeholder text. Each model reply is split on "----" and
# every segment is parsed with json.loads, so the real prompt must ask for
# one JSON answer per numbered problem, separated by "----".
system_prompt = """ prompt    """

@retry(stop_max_attempt_number=5, wait_fixed=2000)
def call_api(messages, max_tokens=8192, timeout=60, model='xxxxx', api_key=None):
    """Send a chat request to DashScope and return the stripped reply text.

    The ``@retry`` decorator makes up to 5 attempts, 2 seconds apart. With
    ``retrying``'s defaults, the last attempt's exception then propagates
    to the caller.

    Args:
        messages: Chat messages, e.g. ``[{'role': ..., 'content': ...}, ...]``.
        max_tokens: Upper bound on generated tokens.
        timeout: Forwarded to ``dashscope.Generation.call``.
            NOTE(review): confirm dashscope actually honours this keyword.
        model: DashScope model name.
        api_key: API key. When omitted, the ``DASHSCOPE_API_KEY`` environment
            variable is used, falling back to the legacy in-source placeholder.

    Returns:
        The content of the first choice, with surrounding whitespace stripped.

    Raises:
        Exception: on a non-OK HTTP status, or whatever the SDK raises.
    """
    if api_key is None:
        # Keep credentials out of source control; the literal is only a fallback.
        api_key = os.environ.get('DASHSCOPE_API_KEY', 'xxxxxx')
    try:
        response = dashscope.Generation.call(
            api_key=api_key,
            model=model,
            messages=messages,
            result_format='message',
            max_tokens=max_tokens,
            temperature=0.0,
            timeout=timeout,
        )
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0].message.content.strip()
        raise Exception(f"API error: {response.code} - {response.message}")
    except Exception as e:
        # Log every failed attempt, then re-raise so @retry can try again.
        print(f"API call failed: {e}")
        raise

def process_batch(batch_questions, batch_idx, max_input_length=58000):
    """Send one batch of problems to the model and split its answer per problem.

    The problems are numbered ("Problem 1: ...") and newline-joined into a
    single user message. If that text is longer than ``max_input_length``
    characters, only the longest leading run of problems that fits is sent.

    Args:
        batch_questions: Problem texts for this batch, in order.
        batch_idx: Batch label used only in log output.
        max_input_length: Character budget for the user message.

    Returns:
        A list with exactly ``len(batch_questions)`` entries, aligned with the
        input. Each entry is the reply segment for that problem, or ``None``
        if the problem was trimmed, the model returned too few segments, or
        the API call failed.
    """
    expected = len(batch_questions)
    batch_text = "\n".join([f"Problem {i+1}: {q}" for i, q in enumerate(batch_questions)])

    if len(batch_text) > max_input_length:
        trimmed_batch = []
        current_length = 0
        for i, q in enumerate(batch_questions):
            problem_text = f"Problem {i+1}: {q}"
            # +1 reserves room for the joining newline.
            if current_length + len(problem_text) + 1 <= max_input_length:
                trimmed_batch.append(q)
                current_length += len(problem_text) + 1
            else:
                print(f"Batch {batch_idx}: Trimmed to {len(trimmed_batch)} problems, input length {current_length}")
                break
        batch_questions = trimmed_batch
        batch_text = "\n".join([f"Problem {i+1}: {q}" for i, q in enumerate(batch_questions)])

    if not batch_questions:
        # Nothing to send (empty batch, or even the first problem is over
        # budget): don't spend an API call on an empty prompt.
        if expected:
            print(f"Batch {batch_idx}: Skipped - no problem fits in {max_input_length} chars")
        return [None] * expected

    sent = len(batch_questions)
    print(f"Batch {batch_idx}: Processing {sent} problems, {len(batch_text)} chars")
    messages = [{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': batch_text}]
    try:
        response = call_api(messages)
        replies = [r.strip() for r in response.split("----") if r.strip()]
    except Exception as e:
        print(f"Batch {batch_idx}: Failed - {e}")
        return [None] * expected

    print(f"Batch {batch_idx}: Received {len(replies)}/{sent} replies")
    if len(replies) > sent:
        # Surplus segments cannot be matched to a problem. If returned, the
        # caller would write them into another batch's result slots.
        print(f"Batch {batch_idx}: Dropping {len(replies) - sent} extra replies")
        replies = replies[:sent]
    # Pad so the result lines up one-to-one with the original (untrimmed) batch.
    return replies + [None] * (expected - len(replies))

def save_results(results, output_file_path):
    """Write every non-None entry of *results* to *output_file_path* as JSON.

    ``None`` entries (problems with no usable reply) are left out. The saved
    list is therefore shorter than ``results`` whenever anything failed.
    Non-ASCII text is written as-is (UTF-8) with 4-space indentation.
    """
    kept = list(filter(lambda item: item is not None, results))
    with open(output_file_path, 'w', encoding='utf-8') as fh:
        json.dump(kept, fh, ensure_ascii=False, indent=4)
    total = len(results)
    print(f"Saved {len(kept)}/{total} valid results")

def process_file(input_file_path, output_file_path, initial_batch_size=5, max_workers=4, test_limit=10,
                 max_input_length=58000):
    """Batch the problems in a JSON file through the model and save parsed replies.

    The input is loaded with ``json.load``. It must be an iterable of
    dict-like entries, because ``entry.get("problem")`` is called on each one.
    Entries with an empty or missing ``"problem"`` are skipped. Batches are
    processed concurrently, each reply is parsed with ``json.loads``, and the
    successfully parsed ones are written via ``save_results``.

    Args:
        input_file_path: Path of the input JSON file.
        output_file_path: Path the results are written to.
        initial_batch_size: Upper bound on problems per API call.
        max_workers: Number of batches processed concurrently.
        test_limit: Maximum number of problems to process. It is a slice
            bound, so ``None`` processes all of them.
        max_input_length: Character budget per request. It drives the batch
            size and is forwarded to ``process_batch``.
    """
    start_time = time.time()

    # Read input file
    with open(input_file_path, 'r', encoding='utf-8') as jsonfile:
        data = json.load(jsonfile)
    questions = [entry.get("problem", "") for entry in data if entry.get("problem", "")][:test_limit]
    if not questions:
        print("Error: No valid problems found")
        return

    # Dynamic batch sizing. The +50 presumably covers the "Problem N: " prefix
    # and newline added per problem.
    avg_length = sum(len(q) for q in questions) / len(questions)
    batch_size = max(1, min(initial_batch_size, int(max_input_length / (avg_length + 50))))
    batches = [questions[i:i + batch_size] for i in range(0, len(questions), batch_size)]
    print(f"Testing {len(questions)} problems, batch_size={batch_size}, batches={len(batches)}, workers={max_workers}")

    # Process batches; results[k] corresponds to questions[k].
    results = [None] * len(questions)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_batch = {
            executor.submit(process_batch, batch, i + 1, max_input_length=max_input_length): (i, batch)
            for i, batch in enumerate(batches)
        }
        for future in as_completed(future_to_batch):
            batch_idx, batch_questions = future_to_batch[future]
            batch_no = batch_idx + 1  # 1-based, same label process_batch logs
            try:
                batch_replies = future.result()
            except Exception as e:
                print(f"Batch {batch_no} failed entirely: {e}")
                continue

            n = len(batch_questions)
            # Clamp to this batch's own slots, so surplus replies can never
            # overwrite another batch's results. Pad short lists so every
            # missing problem is reported.
            replies = list(batch_replies[:n]) + [None] * (n - len(batch_replies))
            start_idx = batch_idx * batch_size
            for j, reply in enumerate(replies):
                idx = start_idx + j
                if not reply:
                    print(f"Batch {batch_no}, Problem {idx + 1}: No reply received")
                    continue
                try:
                    results[idx] = json.loads(reply)
                except json.JSONDecodeError:
                    print(f"Batch {batch_no}, Problem {idx + 1}: Invalid JSON - {reply[:100]}...")

    # Save results
    save_results(results, output_file_path)
    print(f"Processed in {time.time() - start_time:.2f} seconds")

if __name__ == "__main__":
    # Placeholder locations: point these at the real input/output JSON files.
    input_file_path = r"x:xxxxxx"
    output_file_path = r"x:xxxxx"
    process_file(
        input_file_path,
        output_file_path,
        initial_batch_size=5,
        max_workers=4,
        test_limit=100,
    )
    print("Testing complete!")