In [None]:
import json
import os
import time
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

from openai import OpenAI
from pydantic import BaseModel
from dotenv import load_dotenv
from texttools.batch_manager import SimpleBatchManager

In [2]:
##################### Best practices for connecting without errors ####################

# 1- Use a proxy
# 2- Run the code on a VPS

# The first option is better: the data is saved locally if anything goes wrong

# Configurations 

In [None]:
# --- Configuration for batch ---
class BatchConfig:
    """Limits and locations used when splitting input data and saving results."""

    # Number of items per batch part
    MAX_BATCH_SIZE = 1_000
    # Estimated-token ceiling compared against the whole data set when deciding to split
    MAX_TOTAL_TOKENS = 2_000_000
    # Characters-per-token ratio used to turn character counts into token estimates
    CHARS_PER_TOKEN = 2.7
    # As in original code (not referenced anywhere in the visible notebook)
    PROMPT_TOKEN_MULTIPLIER = 1_000
    # Directory where result and log JSON files are written
    BASE_OUTPUT_DIR = "batch_results"

# Helper Functions 

In [None]:
def data_for_batch(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Converts raw data to the required batch format: [{"id": ..., "content": ...}, ...]

    Each element of `data` may be either:
      * a dict with a "content" key -- its "id" is kept when present,
        otherwise the element's position in `data` is used as the id;
      * a plain string -- the string becomes the content and the element's
        position becomes the id.

    Args:
        data: A sequence (normally a list) of dicts and/or strings.

    Returns:
        A new list with one {"id", "content"} dict per input element,
        in the same order as the input.

    Raises:
        ValueError: If `data` itself is a str or dict, or if any element is
            neither a dict containing "content" nor a string.
    """
    # A bare string or a single record dict is iterable, so without this guard
    # it would be walked character-by-character / key-by-key and silently turned
    # into bogus items instead of being reported as malformed input.
    if isinstance(data, (str, dict)):
        raise ValueError(
            f"Expected a list of items, got {type(data).__name__}"
        )

    result = []
    for idx, item in enumerate(data):
        if isinstance(item, dict) and "content" in item:
            result.append({"id": item.get("id", idx), "content": item["content"]})
        elif isinstance(item, str):
            result.append({"id": idx, "content": item})
        else:
            raise ValueError(f"Invalid data item at index {idx}: {item}")
    return result

In [None]:

def parsing_output(part_idx: int, output_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Post-processing hook applied to the results fetched for one part.

    The default implementation is a pass-through: `part_idx` is ignored and
    `output_data` is handed back unchanged (the very same list object, no copy).
    Replace the body with project-specific parsing when needed.
    """
    parsed = output_data  # identity transform
    return parsed

# setup output structure

In [None]:
# Implement this however you want.
# The value for each key can be a bool, an integer, a string, or any other type.
# The model will theoretically obey this structure.

# --- Output Model Example ---
class OutputData(BaseModel):
    # Example response schema: used as the default `output_model` of
    # BatchJobRunner, which hands it to SimpleBatchManager.
    # Replace or extend the field below to describe the output you need.
    desired_output: str

# setup BatchManager with BatchJobRunner

In [None]:
class BatchJobRunner:
    """
    Interactive, console-driven wrapper around SimpleBatchManager.

    Construction builds the manager, loads the input JSON file, normalises it
    with data_for_batch(), splits it into parts and creates the output
    directory. run() then walks the parts one at a time, asking the operator
    to start, check and fetch each batch job; fetched results (and logs, if
    any) are written as JSON under BatchConfig.BASE_OUTPUT_DIR.

    Args:
        system_prompt: Text passed to the manager as `prompt_template`; its
            length is also counted once per item in the token estimate.
        job_name: Base job name; "_part_<n>" is appended when there is more
            than one part.
        input_data_path: Path of the JSON file to read.
        output_data_path: Only the file *stem* of this path is used to name
            the result files; they are always written into BASE_OUTPUT_DIR.
        model: Model name forwarded to the manager.
        output_model: Schema class forwarded to the manager.
    """

    def __init__(self,
                 system_prompt: str,
                 job_name: str,
                 input_data_path: str,
                 output_data_path: str,
                 model: str = "gpt-4.1-mini",
                 output_model=OutputData):
        self.config = BatchConfig()
        self.system_prompt = system_prompt
        self.job_name = job_name
        self.input_data_path = input_data_path
        self.output_data_path = output_data_path
        self.model = model
        self.output_model = output_model
        self.manager = self._init_manager()
        self.data: List[Dict[str, Any]] = []
        self.parts: List[List[Dict[str, Any]]] = []
        self._load_data()
        self._partition_data()
        Path(self.config.BASE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    def _init_manager(self) -> SimpleBatchManager:
        """Create the OpenAI client (key read from OPENAI_API_KEY, after loading .env) and wrap it in a SimpleBatchManager."""
        load_dotenv()
        api_key = os.getenv('OPENAI_API_KEY')
        client = OpenAI(api_key=api_key)
        return SimpleBatchManager(
            client=client,
            model=self.model,
            prompt_template=self.system_prompt,
            output_model=self.output_model
        )

    def _load_data(self):
        """Read the input JSON file (UTF-8) and store it in batch format in self.data."""
        with open(self.input_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = data_for_batch(data)

    def _partition_data(self):
        """
        Estimate the token volume of the whole data set and fill self.parts.

        The estimate is (content chars + prompt chars per item) / CHARS_PER_TOKEN.
        * No items            -> no parts (nothing to submit).
        * Estimate under MAX_TOTAL_TOKENS -> a single part with everything.
        * Otherwise           -> parts limited by both MAX_BATCH_SIZE items
                                 and the estimated-token budget.
        """
        cfg = self.config
        total_length = sum(len(item["content"]) for item in self.data)
        prompt_length = len(self.system_prompt)
        total = total_length + (prompt_length * len(self.data))
        calculation = total / cfg.CHARS_PER_TOKEN
        print(f"Total chars: {total_length}, Prompt chars: {prompt_length}, Total: {total}, Tokens: {calculation}")
        if not self.data:
            # Avoid producing a single empty part, which would prompt the
            # operator to start a batch job that contains no requests.
            self.parts = []
        elif calculation < cfg.MAX_TOTAL_TOKENS:
            self.parts = [self.data]
        else:
            self.parts = self._split_by_budget(prompt_length)
        print(f"Data split into {len(self.parts)} part(s)")

    def _split_by_budget(self, prompt_length: int) -> List[List[Dict[str, Any]]]:
        """
        Greedily pack self.data, in order, into consecutive parts.

        A part is closed as soon as it holds MAX_BATCH_SIZE items or adding
        the next item would bring its estimated tokens to MAX_TOTAL_TOKENS or
        more. Splitting purely by item count (as before) could leave parts
        that still exceed the token budget -- and did not split at all when
        there were no more than MAX_BATCH_SIZE items.
        """
        cfg = self.config
        parts: List[List[Dict[str, Any]]] = []
        current: List[Dict[str, Any]] = []
        current_chars = 0
        for item in self.data:
            # Same per-item cost model as the overall estimate: content + one prompt.
            item_chars = len(item["content"]) + prompt_length
            part_is_full = len(current) >= cfg.MAX_BATCH_SIZE
            over_budget = (current_chars + item_chars) / cfg.CHARS_PER_TOKEN >= cfg.MAX_TOTAL_TOKENS
            # An empty part is never closed, so a single item that is larger
            # than the budget on its own still ends up in a part by itself.
            if current and (part_is_full or over_budget):
                parts.append(current)
                current = []
                current_chars = 0
            current.append(item)
            current_chars += item_chars
        if current:
            parts.append(current)
        return parts

    def run(self):
        """Process every part in order; each part blocks until its results are fetched."""
        if not self.parts:
            print("No input items to process.")
            return
        for idx, part in enumerate(self.parts):
            part_job_name = f"{self.job_name}_part_{idx+1}" if len(self.parts) > 1 else self.job_name
            print(f"\n--- Processing part {idx+1}/{len(self.parts)}: {part_job_name} ---")
            self._process_part(part, part_job_name, idx)

    def _process_part(self, part: List[Dict[str, Any]], part_job_name: str, part_idx: int):
        """
        Command loop for one part.

        Reads commands from the console until results are fetched:
          1 / start -> submit the part through the manager
          2 / check -> print the job status (a failed job has its state cleared)
          3 / fetch -> fetch, post-process and save the results, then return
        """
        while True:
            command = input("Enter command (1.start, 2.check, 3.fetch): ").strip().lower()
            if command in ["1", "start"]:
                self.manager.start(part, job_name=part_job_name)
                print("Started batch job.")
                time.sleep(1)
            elif command in ["2", "check"]:
                status = self.manager.check_status(job_name=part_job_name)
                print(f"Status: {status}")
                # NOTE(review): presumably a throttle between repeated status polls -- confirm.
                time.sleep(5)
                if status == "completed":
                    print("Job completed. You can now fetch results.")
                elif status == "failed":
                    print("Job failed. Clearing state.")
                    # Uses the manager's private helper; the loop keeps running afterwards.
                    self.manager._clear_state(part_job_name)
            elif command in ["3", "fetch"]:
                output_data, log = self.manager.fetch_results(job_name=part_job_name, save=True, remove_cache=False)
                output_data = parsing_output(part_idx, output_data)
                self._save_results(output_data, log, part_idx)
                print("Fetched and saved results for this part.")
                break
            else:
                print("Invalid command. Please enter 1, 2, or 3.")

    def _save_results(self, output_data: List[Dict[str, Any]], log: List[Any], part_idx: int):
        """
        Write one part's results (and its log, when non-empty) as UTF-8 JSON.

        Files are named "<stem of output_data_path>[_part_<n>].json" and
        "..._log.json" inside BASE_OUTPUT_DIR; the part suffix is only added
        when there is more than one part.
        """
        part_suffix = f"_part_{part_idx+1}" if len(self.parts) > 1 else ""
        result_path = Path(self.config.BASE_OUTPUT_DIR) / f"{Path(self.output_data_path).stem}{part_suffix}.json"
        with open(result_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, ensure_ascii=False, indent=4)
        if log:
            log_path = Path(self.config.BASE_OUTPUT_DIR) / f"{Path(self.output_data_path).stem}{part_suffix}_log.json"
            with open(log_path, 'w', encoding='utf-8') as f:
                json.dump(log, f, ensure_ascii=False, indent=4)

# start the Job

In [None]:
def main():
    """
    Console entry point.

    Asks, in order, for the system prompt, the job name, the input JSON path
    and the output JSON path (each answer is stripped of surrounding
    whitespace), then builds a BatchJobRunner from them and runs it.
    """
    print("=== Batch Job Runner ===")
    questions = (
        "Enter system prompt: ",
        "Enter job name: ",
        "Enter input data path (JSON): ",
        "Enter output data path (JSON): ",
    )
    # The generator is consumed left to right, so the prompts appear in the listed order.
    system_prompt, job_name, input_data_path, output_data_path = (
        input(question).strip() for question in questions
    )
    BatchJobRunner(system_prompt, job_name, input_data_path, output_data_path).run()

In [None]:
# Launch the interactive runner; this cell blocks on console prompts until every part is fetched.
main()