# In [3]:
import json
import time
import os
from openai import OpenAI
from tqdm import tqdm


class OpenAIProcessor:
    """Run a JSON file of prompts through a chat-completions API and checkpoint the replies.

    The input file maps sample ids to sample dicts. Each model reply is parsed
    as JSON and stored on the sample under ``'api_response'``. Three files are
    maintained inside ``output_path``:

    * ``processed_data.json``   -- samples that received a parseable reply
    * ``failed_samples.json``   -- ids of samples that exhausted their attempts
    * ``processing_errors.json`` -- last error message recorded per sample id
    """

    # Number of prompts kept when ``debug_mode`` is enabled.
    DEBUG_SAMPLE_LIMIT = 3

    # Marker used to locate an image URL appended to the prompt text.
    URL_SCHEME = "https://"

    def __init__(self, api_key, api_base, dataset_name, max_seq_len, template_id, model_name, incremental_mode=False):
        """Create the API client and derive the input/output locations.

        Args:
            api_key: Key passed to the ``OpenAI`` client.
            api_base: Base URL passed to the ``OpenAI`` client.
            dataset_name: Used, together with ``max_seq_len``, to name the
                prompt and output directories.
            max_seq_len: Second component of the directory name.
            template_id: Selects ``prompts_<template_id>.json`` and controls
                how ``process_sample`` builds the message.
            model_name: Model requested from the API; also part of the output
                directory name.
            incremental_mode: When true, read from ``./incremental_prompts``
                and write to ``./incremental_output`` instead of
                ``./prompts`` and ``./test``.
        """
        self.model_name = model_name
        self.template_id = template_id
        self.client = self.initialize_client(api_key, api_base)

        if incremental_mode:
            input_root, output_root = './incremental_prompts', './incremental_output'
        else:
            input_root, output_root = './prompts', './test'
        run_dir = f'{dataset_name}_{max_seq_len}'
        self.input_file = f'{input_root}/{run_dir}/prompts_{template_id}.json'
        self.output_path = f'{output_root}/{run_dir}/prompts_{template_id}_{model_name}/'

    def initialize_client(self, api_key, api_base):
        """Return an ``OpenAI`` client bound to ``api_key`` and ``api_base``."""
        return OpenAI(api_key=api_key, base_url=api_base)

    def read_json(self, file_path):
        """Load and return the JSON document stored at ``file_path``."""
        # JSON files are UTF-8 by convention; do not depend on the platform default.
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def write_json(self, data, file_path):
        """Serialize ``data`` to ``file_path`` as indented JSON, replacing any existing file."""
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, indent=4)

    def call_api(self, content, model_name):
        """Send ``content`` as a single user message and return the API response.

        If ``content`` contains ``https://``, the text before the first
        occurrence becomes a text part and the segment following it becomes an
        image part requested at low detail. Otherwise the stripped text is sent
        as a plain string.

        Args:
            content: Prompt text, optionally followed by an image URL.
            model_name: Model to request.

        Returns:
            Whatever ``client.chat.completions.create`` returns.
        """
        text, *url_segments = content.split(self.URL_SCHEME)
        text = text.strip()

        if url_segments:
            # Only the segment right after the first marker is used as the URL.
            image_url = self.URL_SCHEME + url_segments[0].strip()
            # The content parts must be sent as a list; converting them to a
            # string would make the model receive their repr as plain text and
            # the image would never be attached.
            user_content = [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"detail": "low", "url": image_url}},
            ]
        else:
            user_content = text

        return self.client.chat.completions.create(
            model=model_name,
            messages=[{'role': 'user', 'content': user_content}],
            temperature=0.0,
            max_tokens=4096,
        )

    def save_intermediate_results(self, processed_samples, failed_samples, processing_errors):
        """Merge the given results into the checkpoint files in ``output_path``.

        Existing files are read first so that earlier content is kept: processed
        samples and errors are merged by sample id (new values win), failed ids
        are appended without duplicates. Ids that are present in the merged
        processed samples are dropped from the failed list, so a sample that
        failed earlier and succeeded later is not reported as failed.

        The caller's containers are not modified.
        """
        processed_path = os.path.join(self.output_path, 'processed_data.json')
        failed_path = os.path.join(self.output_path, 'failed_samples.json')
        errors_path = os.path.join(self.output_path, 'processing_errors.json')

        # Read and update the processed samples.
        if os.path.exists(processed_path):
            merged_processed = self.read_json(processed_path)
            merged_processed.update(processed_samples)
            processed_samples = merged_processed

        # Read and update the failed samples (order-preserving de-duplication).
        if os.path.exists(failed_path):
            failed_samples = list(dict.fromkeys([*self.read_json(failed_path), *failed_samples]))
        failed_samples = [sample_id for sample_id in failed_samples if sample_id not in processed_samples]

        # Read and update the processing errors.
        if os.path.exists(errors_path):
            merged_errors = self.read_json(errors_path)
            merged_errors.update(processing_errors)
            processing_errors = merged_errors

        # Save the updated results.
        self.write_json(processed_samples, processed_path)
        self.write_json(failed_samples, failed_path)
        self.write_json(processing_errors, errors_path)
        print('Sample Updated.')

    def process_sample(self, sample_data, model_name, timeout_duration=120):
        """Query the model for one sample and return its reply parsed as JSON.

        Args:
            sample_data: Sample dict with a ``'prompt'`` entry. For template ids
                containing ``'s'``, ``'r-1'`` or ``'r-2'`` it must also provide
                ``sample_data['history']['online_combined_img_path']``, which is
                appended to the prompt on its own line.
            model_name: Model to request.
            timeout_duration: Accepted for interface compatibility; no timeout
                is currently enforced by this method.

        Returns:
            The decoded JSON value of the reply after Markdown code fences
            have been removed.

        Raises:
            json.JSONDecodeError: If the cleaned reply is not valid JSON.
            Exception: Anything raised by the API call is propagated.
        """
        content = sample_data['prompt']
        template_id = self.template_id
        if 's' in template_id or 'r-2' in template_id or 'r-1' in template_id:
            content += "\n" + sample_data['history']['online_combined_img_path']
        elif 'r-3' in template_id:
            content += "\n"

        response = self.call_api(content, model_name)
        message_content = response.choices[0].message.content
        # Models often wrap JSON in ```json fences; strip them before parsing.
        cleaned_content = message_content.replace('```json', '').replace('```', '').strip()
        parsed = json.loads(cleaned_content)
        print(parsed)
        return parsed

    def process_samples(self, prompts_data, model_name, max_retries=2, timeout_duration=120):
        """Process every sample, retrying failures, and checkpoint after each one.

        Each sample gets one attempt plus up to ``max_retries`` retries. On
        success the reply is stored in ``sample_data['api_response']`` (the
        input dict is modified). The most recent error of a sample is kept in
        the returned error mapping even if a later attempt succeeds.

        Args:
            prompts_data: Mapping of sample id to sample dict.
            model_name: Model to request.
            max_retries: Number of additional attempts after the first failure.
            timeout_duration: Forwarded to ``process_sample``.

        Returns:
            ``(processed_samples, failed_samples, processing_errors)`` for the
            samples handled by this call.
        """
        processed_samples = {}
        failed_samples = []
        processing_errors = {}
        attempts = max(max_retries, 0) + 1

        for sample_id, sample_data in tqdm(prompts_data.items(), desc="Processing samples"):
            for attempt in range(attempts):
                try:
                    api_response = self.process_sample(sample_data, model_name, timeout_duration)
                except Exception as error:
                    # Deliberately broad: one bad sample (network error,
                    # malformed reply, missing key) must not abort the run.
                    error_key = 'error' if attempt == 0 else 'retry_error'
                    processing_errors[sample_id] = {error_key: str(error)}
                else:
                    sample_data['api_response'] = api_response
                    processed_samples[sample_id] = sample_data
                    break
            else:
                # Every attempt failed.
                failed_samples.append(sample_id)

            # Checkpoint after each sample so an interrupted run can resume.
            self.save_intermediate_results(processed_samples, failed_samples, processing_errors)

        # Save the final results (also creates the files when there was nothing to process).
        self.save_intermediate_results(processed_samples, failed_samples, processing_errors)

        return processed_samples, failed_samples, processing_errors

    def process_samples_and_save(self, resume_from_last=False, debug_mode=False):
        """Process the prompt file end to end and write the three result files.

        Args:
            resume_from_last: When true and ``processed_data.json`` already
                exists in ``output_path``, samples listed there are skipped.
            debug_mode: When true, only the first ``DEBUG_SAMPLE_LIMIT``
                prompts of the input file are considered.

        Returns:
            ``(output_path, total_time, average_time_per_sample)``. The average
            divides the elapsed seconds by the number of prompts considered
            (including ones skipped because of resuming) and is ``0.0`` when
            there are no prompts.
        """
        prompts_data = self.read_json(self.input_file)

        if debug_mode:
            limit = self.DEBUG_SAMPLE_LIMIT
            print(f"Debug mode is ON: Processing only the first {limit} samples.")
            prompts_data = dict(list(prompts_data.items())[:limit])

        processed_path = os.path.join(self.output_path, 'processed_data.json')

        # Load previously processed samples if resume_from_last is True.
        processed_samples = {}
        if resume_from_last and os.path.exists(processed_path):
            processed_samples = self.read_json(processed_path)
            print(f"Resuming from last session. {len(processed_samples)} samples already processed.")

        # exist_ok avoids a check-then-create race when several workers start at once.
        os.makedirs(self.output_path, exist_ok=True)

        start_time = time.time()

        pending = {
            sample_id: sample
            for sample_id, sample in prompts_data.items()
            if sample_id not in processed_samples
        }
        new_processed_samples, failed_samples, processing_errors = self.process_samples(
            pending,
            self.model_name
        )

        # Update the processed_samples dictionary with new data.
        processed_samples.update(new_processed_samples)

        # Save the results.
        self.write_json(processed_samples, processed_path)
        self.write_json(failed_samples, os.path.join(self.output_path, 'failed_samples.json'))
        self.write_json(processing_errors, os.path.join(self.output_path, 'processing_errors.json'))

        total_time = time.time() - start_time
        # Guard against an empty prompt file, which would otherwise divide by zero.
        average_time_per_sample = total_time / len(prompts_data) if prompts_data else 0.0

        return self.output_path, total_time, average_time_per_sample


# In [None]:
import concurrent.futures
# from tqdm.notebook import tqdm  # Specifically for Jupyter Notebook

# Usage
# Credentials are read from the environment when available so that real keys
# do not have to be written into this file; the placeholders remain as the
# fallback values.
api_key = os.environ.get("OPENAI_API_KEY", "xxxxxx")
api_base = os.environ.get("OPENAI_BASE_URL", "xxxxxx")
# Passed to OpenAIProcessor, where it forms part of the prompt and output
# directory names ("<dataset_name>_<max_seq_len>").
max_seq_len = 10

def run_task(dataset_name, template_id, model_name, incremental_mode, debug_mode):
    """
    Run one dataset/template/model configuration through OpenAIProcessor and report timings.

    The processor is built with the module-level ``api_key``, ``api_base`` and
    ``max_seq_len`` and is always asked to resume from previously saved results.

    Args:
        dataset_name (str): Dataset whose prompts are processed.
        template_id (str): Prompt template identifier.
        model_name (str): Model requested from the API.
        incremental_mode (bool): Forwarded to OpenAIProcessor, where it selects
            the incremental prompt and output directories.
        debug_mode (bool): Forwarded to ``process_samples_and_save``, which
            then handles only a small leading subset of the prompts.

    Returns:
        The tuple returned by ``process_samples_and_save``:
        ``(output_path, total_time, average_time_per_sample)``.
    """
    outcome = OpenAIProcessor(
        api_key,
        api_base,
        dataset_name,
        max_seq_len,
        template_id,
        model_name,
        incremental_mode,
    ).process_samples_and_save(resume_from_last=True, debug_mode=debug_mode)

    # Both summary lines share the same prefix identifying the task.
    task_label = f"Dataset: {dataset_name}, Template ID: {template_id}"
    print(f"{task_label}, Total time taken: {outcome[1]:.2f} seconds")
    print(f"{task_label}, Average time per sample: {outcome[2]:.2f} seconds")
    return outcome


# Each entry: (dataset_name, template_id, model_name, incremental_mode, debug_mode).
tasks = [
    ('beauty', 'r-1', 'gpt-4o-2024-05-13', False, False),
    ('beauty', 'r-2', 'gpt-4o-2024-05-13', False, False),
    ('beauty', 'r-3', 'gpt-4o-2024-05-13', False, False),
    ('beauty', 'r-3', 'gpt-4-0125-preview', False, False),
    ('clothing', 'r-1', 'gpt-4o-2024-05-13', False, False),
    ('clothing', 'r-2', 'gpt-4o-2024-05-13', False, False),
    ('clothing', 'r-3', 'gpt-4o-2024-05-13', False, False),
    ('clothing', 'r-3', 'gpt-4-0125-preview', False, False),
    ('sports', 'r-1', 'gpt-4o-2024-05-13', False, False),
    ('sports', 'r-2', 'gpt-4o-2024-05-13', False, False),
    ('sports', 'r-3', 'gpt-4o-2024-05-13', False, False),
    ('sports', 'r-3', 'gpt-4-0125-preview', False, False),
    ('toys', 'r-1', 'gpt-4o-2024-05-13', False, False),
    ('toys', 'r-2', 'gpt-4o-2024-05-13', False, False),
    ('toys', 'r-3', 'gpt-4o-2024-05-13', False, False),
    ('toys', 'r-3', 'gpt-4-0125-preview', False, False),
    ('beauty', 'r-1', 'gpt-4-vision-preview', False, False),
    ('beauty', 'r-2', 'gpt-4-vision-preview', False, False),
    ('beauty', 'r-3', 'gpt-4-vision-preview', False, False),
    ('clothing', 'r-1', 'gpt-4-vision-preview', False, False),
    ('clothing', 'r-2', 'gpt-4-vision-preview', False, False),
    ('clothing', 'r-3', 'gpt-4-vision-preview', False, False),
    ('sports', 'r-1', 'gpt-4-vision-preview', False, False),
    ('sports', 'r-2', 'gpt-4-vision-preview', False, False),
    ('sports', 'r-3', 'gpt-4-vision-preview', False, False),
    ('toys', 'r-1', 'gpt-4-vision-preview', False, False),
    ('toys', 'r-2', 'gpt-4-vision-preview', False, False),
    ('toys', 'r-3', 'gpt-4-vision-preview', False, False),
]

# A task's output directory is derived from its dataset, template and model, so
# submitting the same configuration twice would make two threads process the
# same prompts and overwrite each other's checkpoint files. Drop any repeated
# entry while keeping the original order.
unique_tasks = list(dict.fromkeys(tasks))

# Use ThreadPoolExecutor to run tasks in parallel, one worker thread per task.
with concurrent.futures.ThreadPoolExecutor(max_workers=len(unique_tasks)) as executor:
    futures = [executor.submit(run_task, *task) for task in unique_tasks]

    # run_task prints its own summary. result() is still called so that an
    # exception raised inside a worker is re-raised here rather than lost.
    for future in concurrent.futures.as_completed(futures):
        future.result()