# LVLMRecommender Usage Guide
This notebook demonstrates how to use the `LVLMRecommender` class to process samples from a dataset using either OpenAI or Claude APIs. The following steps outline the process of setting up and executing the tasks.

## Step 1: Import Necessary Libraries

First, import all necessary libraries and modules.

In [None]:
import json
import time
import os
from openai import OpenAI
from anthropic import Anthropic
import base64
import httpx
from tqdm import tqdm
import concurrent.futures


## Step 2: Define the LVLMRecommender Class

The `LVLMRecommender` class encapsulates all the functionality required to interact with the APIs and process the samples. Run this cell before the cells below, since they all depend on this class definition.


In [None]:
class LVLMRecommender:
    """
    Send sampled recommendation prompts to an LVLM API (OpenAI or Claude)
    and persist the parsed JSON responses.

    Prompts are read from
    ./prompts/sampled_prompts/{dataset_name}_{max_seq_len}/prompts_{template_id}.json
    and results are written under
    ./results/{dataset_name}_{max_seq_len}/prompts_{template_id}_{model_name}/
    as processed_data.json, failed_samples.json and processing_errors.json.
    """

    # Template IDs whose prompts get the sample's combined history image URL appended.
    _IMAGE_TEMPLATE_IDS = frozenset({"s-1-image", "s-1-title-image", "s-2", "s-3"})

    # Image media types accepted for Claude base64 image blocks; any other
    # Content-Type falls back to "image/png" (the previous hard-coded value).
    _CLAUDE_IMAGE_MEDIA_TYPES = frozenset({"image/jpeg", "image/png", "image/gif", "image/webp"})

    def __init__(self, api_type, api_key, base_url, dataset_name, max_seq_len, template_id, model_name, incremental_mode=False):
        """
        Initialize the recommender and its API client.

        Args:
            api_type: 'openai' or 'claude'; selects the client and request format.
            api_key: credential handed to the SDK client.
            base_url: API endpoint handed to the SDK client.
            dataset_name: dataset name, used to build the input/output paths.
            max_seq_len: sequence length, used to build the input/output paths.
            template_id: prompt template ID; also decides whether an image URL is appended.
            model_name: model to query; also part of the output path.
            incremental_mode: stored on the instance but not read by this class.

        Raises:
            ValueError: if ``api_type`` is not 'openai' or 'claude'.
        """
        self.api_type = api_type
        self.model_name = model_name
        self.template_id = template_id
        self.incremental_mode = incremental_mode  # kept for interface compatibility; currently unused
        self.client = self.initialize_client(api_key, base_url)
        self.input_file = f'./prompts/sampled_prompts/{dataset_name}_{max_seq_len}/prompts_{template_id}.json'
        self.output_path = f'./results/{dataset_name}_{max_seq_len}/prompts_{template_id}_{model_name}/'

    def initialize_client(self, api_key, base_url):
        """
        Create the SDK client matching ``self.api_type``.

        Raises:
            ValueError: if ``self.api_type`` is unsupported.
        """
        if self.api_type == 'openai':
            return OpenAI(api_key=api_key, base_url=base_url)
        if self.api_type == 'claude':
            # NOTE(review): the key is passed as ``auth_token`` (bearer auth) rather than
            # ``api_key``; presumably because base_url points at a proxy. Confirm.
            return Anthropic(base_url=base_url, auth_token=api_key)
        raise ValueError("Invalid API type. Choose 'openai' or 'claude'.")

    def read_json(self, file_path):
        """Read and return the JSON document stored at ``file_path`` (UTF-8)."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def write_json(self, data, file_path):
        """
        Write ``data`` as indented JSON to ``file_path``.

        The data is written to a sibling ``.tmp`` file and then moved into place
        with ``os.replace``. Results are rewritten after every sample, so an
        interrupted write must not leave a truncated file that breaks resuming.
        """
        tmp_path = f"{file_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, indent=4)
        os.replace(tmp_path, file_path)

    @staticmethod
    def _split_text_and_image(content):
        """
        Split a prompt into (text, image_url).

        Everything before the first "https://" is the text (stripped). Everything
        from that marker on is the image URL, or None if there is no marker.
        ``partition`` keeps URLs that themselves contain "https://" (e.g. in a
        query string) intact.
        """
        text, marker, rest = content.partition("https://")
        image_url = marker + rest.strip() if marker else None
        return text.strip(), image_url

    def call_api(self, content, model_name, timeout=None):
        """
        Dispatch ``content`` to the API selected by ``self.api_type``.

        Args:
            content: prompt text, optionally followed by an https:// image URL.
            model_name: model identifier to request.
            timeout: per-request timeout in seconds; None keeps the SDK default.

        Raises:
            ValueError: if ``self.api_type`` is unsupported.
        """
        if self.api_type == 'openai':
            return self.call_openai_api(content, model_name, timeout=timeout)
        if self.api_type == 'claude':
            return self.call_claude_api(content, model_name, timeout=timeout)
        raise ValueError("Invalid API type. Choose 'openai' or 'claude'.")

    def call_openai_api(self, content, model_name, timeout=None):
        """
        Call the OpenAI chat completions API with ``content``.

        If the prompt carries an image URL, it is sent as a structured list of
        text + image_url parts. A stringified list would reach the model as
        plain text and the image would never be attached.
        """
        text, image_url = self._split_text_and_image(content)
        if image_url:
            message_content = [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"detail": "low", "url": image_url}},
            ]
        else:
            message_content = text

        request = {
            'model': model_name,
            'messages': [{'role': 'user', 'content': message_content}],
            'temperature': 0.0,
            'max_tokens': 4096,
        }
        if timeout is not None:
            request['timeout'] = timeout
        return self.client.chat.completions.create(**request)

    def call_claude_api(self, content, model_name, timeout=None):
        """
        Call the Claude messages API with ``content``.

        An image URL in the prompt is downloaded and sent inline as base64. Its
        media type comes from the response Content-Type when that is a supported
        image type, and otherwise defaults to image/png.

        Raises:
            httpx.HTTPStatusError: if the image download returns an error status.
        """
        text, image_url = self._split_text_and_image(content)
        if image_url:
            image_response = httpx.get(image_url)
            # Fail fast instead of base64-encoding an HTML error page as an "image".
            image_response.raise_for_status()
            header_type = image_response.headers.get('content-type', '').split(';')[0].strip().lower()
            media_type = header_type if header_type in self._CLAUDE_IMAGE_MEDIA_TYPES else "image/png"
            blocks = [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64.b64encode(image_response.content).decode("utf-8"),
                    },
                },
                {
                    "type": "text",
                    "text": text + ' Just output JSON format without any description, need to generate a complete JSON format.'
                },
            ]
        else:
            blocks = [
                {
                    "type": "text",
                    "text": text + ' Just output completed JSON format without any description. \nJust output the first 30 characters of each recommendation item.'
                }
            ]

        request = {
            'model': model_name,
            'temperature': 0.0,
            'max_tokens': 1024,
            'messages': [{"role": "user", "content": blocks}],
        }
        if timeout is not None:
            request['timeout'] = timeout
        return self.client.messages.create(**request)

    def _extract_text(self, response):
        """
        Return the text of the first output of an SDK response object.

        OpenAI returns ``choices[0].message.content``; Anthropic returns
        ``content[0].text``. Neither response object supports ``[]`` indexing.

        Raises:
            ValueError: on an unsupported api_type or an empty text payload.
        """
        if self.api_type == 'openai':
            text = response.choices[0].message.content
        elif self.api_type == 'claude':
            text = response.content[0].text
        else:
            raise ValueError("Invalid API type. Choose 'openai' or 'claude'.")
        if not text:
            raise ValueError("API response contained no text content.")
        return text

    def save_intermediate_results(self, processed_samples, failed_samples, processing_errors):
        """
        Merge the in-memory results into the JSON files in ``self.output_path``.

        processed_data.json and processing_errors.json are dict-merged, with
        in-memory entries winning. failed_samples.json gets any IDs not already
        listed appended, in order.
        """
        processed_samples_path = os.path.join(self.output_path, 'processed_data.json')
        failed_samples_path = os.path.join(self.output_path, 'failed_samples.json')
        processing_errors_path = os.path.join(self.output_path, 'processing_errors.json')

        if os.path.exists(processed_samples_path):
            merged_processed = self.read_json(processed_samples_path)
            merged_processed.update(processed_samples)
            processed_samples = merged_processed

        if os.path.exists(failed_samples_path):
            merged_failed = self.read_json(failed_samples_path)
            already_listed = set(merged_failed)  # O(1) membership instead of scanning the list
            merged_failed.extend([sample for sample in failed_samples if sample not in already_listed])
            failed_samples = merged_failed

        if os.path.exists(processing_errors_path):
            merged_errors = self.read_json(processing_errors_path)
            merged_errors.update(processing_errors)
            processing_errors = merged_errors

        self.write_json(processed_samples, processed_samples_path)
        self.write_json(failed_samples, failed_samples_path)
        self.write_json(processing_errors, processing_errors_path)
        print('Sample Updated.')

    def process_sample(self, sample_data, model_name, timeout_duration=120):
        """
        Query the API for one sample and return its response parsed as JSON.

        For image templates, ``sample_data['history']['online_combined_img_path']``
        is appended to the prompt on a new line.

        Args:
            sample_data: dict with at least a 'prompt' key.
            model_name: model identifier to request.
            timeout_duration: per-request timeout in seconds (None = SDK default).

        Raises:
            json.JSONDecodeError: if the cleaned response is not valid JSON.
            ValueError: if the response has no text.
        """
        content = sample_data['prompt']
        if self.template_id in self._IMAGE_TEMPLATE_IDS:
            content += "\n" + sample_data['history']['online_combined_img_path']

        response = self.call_api(content, model_name, timeout=timeout_duration)
        message_content = self._extract_text(response)
        # Models often wrap JSON in markdown fences; strip them before parsing.
        cleaned_content = message_content.replace('```json', '').replace('```', '').strip()
        return json.loads(cleaned_content)

    def process_samples(self, prompts_data, model_name, max_retries=2, timeout_duration=120, save_every=1):
        """
        Process every sample in ``prompts_data``, retrying failures.

        Each sample gets 1 + ``max_retries`` attempts. Successful responses are
        stored under ``sample_data['api_response']``.

        Args:
            prompts_data: mapping of sample_id -> sample dict.
            model_name: model identifier to request.
            max_retries: extra attempts after the first failure (negative = 0).
            timeout_duration: per-request timeout in seconds.
            save_every: save intermediate results after every N successes;
                a falsy value saves only once at the end.

        Returns:
            (processed_samples, failed_samples, processing_errors). The first
            error of a sample is recorded under 'error' and later ones under
            'retry_error'; the latest entry overwrites earlier ones.
        """
        processed_samples = {}
        failed_samples = []
        processing_errors = {}
        total_attempts = 1 + max(0, max_retries)
        completed = 0

        for sample_id, sample_data in tqdm(prompts_data.items(), desc="Processing samples"):
            for attempt in range(total_attempts):
                try:
                    api_response = self.process_sample(sample_data, model_name, timeout_duration)
                except Exception as e:  # any API or parsing failure counts as a failed attempt
                    error_key = 'error' if attempt == 0 else 'retry_error'
                    processing_errors[sample_id] = {error_key: str(e)}
                    continue
                sample_data['api_response'] = api_response
                processed_samples[sample_id] = sample_data
                completed += 1
                if save_every and completed % save_every == 0:
                    self.save_intermediate_results(processed_samples, failed_samples, processing_errors)
                break
            else:
                # Every attempt raised.
                failed_samples.append(sample_id)

        self.save_intermediate_results(processed_samples, failed_samples, processing_errors)
        return processed_samples, failed_samples, processing_errors

    def process_samples_and_save(self, resume_from_last=False, debug_mode=False, debug_limit=3):
        """
        Process the input prompts and write the results to ``self.output_path``.

        Args:
            resume_from_last: skip sample IDs already in processed_data.json.
            debug_mode: only process the first ``debug_limit`` samples.
            debug_limit: number of samples processed in debug mode.

        Returns:
            (output_path, total_time, average_time_per_sample). The average is
            taken over the samples actually sent this run (0.0 if none), so
            samples skipped on resume do not dilute it.
        """
        prompts_data = self.read_json(self.input_file)

        if debug_mode:
            print(f"Debug mode is ON: Processing only the first {debug_limit} samples.")
            prompts_data = dict(list(prompts_data.items())[:debug_limit])

        processed_samples_path = os.path.join(self.output_path, 'processed_data.json')
        processed_samples = {}
        if resume_from_last and os.path.exists(processed_samples_path):
            processed_samples = self.read_json(processed_samples_path)
            print(f"Resuming from last session. {len(processed_samples)} samples already processed.")

        os.makedirs(self.output_path, exist_ok=True)

        pending = {k: v for k, v in prompts_data.items() if k not in processed_samples}

        start_time = time.perf_counter()
        new_processed_samples, failed_samples, processing_errors = self.process_samples(
            pending,
            self.model_name
        )
        processed_samples.update(new_processed_samples)

        self.write_json(processed_samples, processed_samples_path)
        self.write_json(failed_samples, os.path.join(self.output_path, 'failed_samples.json'))
        self.write_json(processing_errors, os.path.join(self.output_path, 'processing_errors.json'))

        total_time = time.perf_counter() - start_time
        average_time_per_sample = total_time / len(pending) if pending else 0.0

        return self.output_path, total_time, average_time_per_sample

## Step 3: Set Up Configuration Parameters

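Define the configuration parameters such as API keys, base URLs, dataset names, sequence length, template IDs, and model names. These parameters will be used to instantiate the `LVLMRecommender` class. Note that `run_task` (Step 4) reads only `api_key`, `base_url`, and `max_seq_len` from this cell; the dataset name, template ID, and model name are taken from each task tuple defined in Step 5.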


In [1]:
# Notebook-wide configuration.
# Credentials are read from the environment when set, so secrets do not have to
# be hard-coded (and accidentally committed) in the notebook; otherwise the
# placeholder values below are used unchanged.
api_key = os.environ.get("LVLM_API_KEY", "your_api_key_here")
base_url = os.environ.get("LVLM_BASE_URL", "your_base_url_here")
dataset_name = "your_dataset_name_here"
max_seq_len = 10  # part of the prompt/result directory names: {dataset_name}_{max_seq_len}
template_id = "your_template_id_here"
model_name = "your_model_name_here"

## Step 4: Define the Task Execution Function

Define `run_task`, which executes a single processing task with the `LVLMRecommender` class. It processes the samples (resuming from any results already saved), saves the results, and prints the total and average per-sample processing time.


In [None]:

def run_task(api_type, dataset_name, template_id, model_name, debug_mode):
    """
    Run one LVLMRecommender job and report how long it took.

    The recommender is built from the arguments together with the
    notebook-level ``api_key``, ``base_url`` and ``max_seq_len``. The job
    resumes from previously saved results, prints the total and per-sample
    timings, and returns the tuple produced by ``process_samples_and_save``.
    """
    recommender = LVLMRecommender(
        api_type, api_key, base_url, dataset_name, max_seq_len, template_id, model_name
    )
    outcome = recommender.process_samples_and_save(resume_from_last=True, debug_mode=debug_mode)

    elapsed, per_sample = outcome[1], outcome[2]
    label = f"Dataset: {dataset_name}, Template ID: {template_id}"
    print(f"{label}, Total time taken: {elapsed:.2f} seconds")
    print(f"{label}, Average time per sample: {per_sample:.2f} seconds")
    return outcome


## Step 5: Define the Tasks

Specify the tasks you want to execute. Each task is a tuple of `(api_type, dataset_name, template_id, model_name, debug_mode)`; when `debug_mode` is `True`, only the first few samples of the dataset are processed.


In [None]:
# Each task is a tuple: (api_type, dataset_name, template_id, model_name, debug_mode).
tasks = list()
tasks.append(('openai', 'beauty', 'r-1', 'gpt-4o-2024-05-13', False))
# Append further task tuples here as needed.

## Step 6: Execute the Tasks Using ThreadPoolExecutor

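Use `concurrent.futures.ThreadPoolExecutor` to run the tasks concurrently, one worker thread per task. Each task writes to an output directory derived from its dataset, sequence length, template ID, and model name, so avoid listing two tasks with the same combination.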


In [None]:
# Run every task on its own worker thread.
# max(1, ...) because ThreadPoolExecutor raises ValueError for max_workers=0,
# which would otherwise happen when the task list is empty.
with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, len(tasks))) as executor:
    futures = [
        executor.submit(run_task, api_type, ds_name, tpl_id, task_model, debug_mode)
        for api_type, ds_name, tpl_id, task_model, debug_mode in tasks
    ]

    # result() re-raises any exception raised inside a task, so failures surface here.
    for future in concurrent.futures.as_completed(futures):
        future.result()
