# In [None]:
import os
import asyncio
import base64
import io
import openai
from openai import AsyncOpenAI
from PIL import Image
from tqdm.asyncio import tqdm_asyncio
from sklearn.metrics import accuracy_score
import numpy as np

# --- 0. Setup API Client and Environment ---
# The client talks to an OpenAI-compatible gateway that serves the Gemini model.
print("Setting up Gemini API client...")

# SECURITY: the API key must never live in source code. It is read from the
# environment instead, so export OPENAI_API_KEY before running this cell.
# Any key that was previously hard-coded here should be treated as leaked
# and rotated.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export the gateway API key in the "
        "environment before running this cell."
    )
os.environ["OPENAI_BASE_URL"] = "https://api.ai-gaochao.cn/v1"

# AsyncOpenAI() picks up OPENAI_API_KEY and OPENAI_BASE_URL from the environment.
client = AsyncOpenAI()

# --- 1. Prepare Data and Prompt ---
# Reusing variables from Subtask 2
# images_to_test: A list of 1010 PIL Image objects
# true_labels: A list of 1010 integer labels
# class_names: A list of 101 string class names

# Helper function to convert PIL image to base64
def pil_to_base64(image: Image.Image, format="jpeg") -> str:
    """Serialize a PIL image and return the bytes as base64 text.

    Args:
        image: The image to serialize.
        format: Pillow format name handed to ``Image.save`` (default "jpeg").

    Returns:
        The base64 encoding (as a ``str``, without any data-URL prefix) of the
        image bytes written in ``format``.
    """
    # JPEG has no alpha channel or palette, and Pillow refuses to write such
    # modes (e.g. RGBA, LA, P). Convert those to RGB so one PNG-derived image
    # does not break the run; modes JPEG already accepts are left untouched.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if format.lower() == "jpeg" and image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format=format)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

# Comma-separated roster of every candidate label; it is embedded in the prompt.
class_list_str = ", ".join(class_names)

# Instruction text sent alongside each image. It asks for the bare class name
# only, so the reply can be mapped back to a label with simple string matching.
PROMPT_TEMPLATE = """
You are an expert food classifier. Your task is to identify the food in the image and respond with ONLY the corresponding class name from the provided list. Do not add any extra text, explanations, or punctuation.

Here are some examples of correct responses:
- If you see a picture of a hamburger, you should respond with: hamburger
- If you see a picture of sushi, you should respond with: sushi

Now, identify the food in the following image.

List of possible classes: {class_list_str}

Your answer:
""".format(class_list_str=class_list_str)

print(f"Data and prompt prepared. Will classify {len(images_to_test)} images.")




# --- 2. Define Asynchronous API Call Function ---
async def classify_image_with_gemini(image: Image.Image, session_client: AsyncOpenAI, semaphore: asyncio.Semaphore):
    """Ask the model, via the chat-completions API, to name the food in one image.

    Args:
        image: PIL image to classify.
        session_client: Async OpenAI-compatible client used to send the request.
        semaphore: Shared semaphore that bounds how many requests run at once.

    Returns:
        The model's reply with surrounding whitespace and double quotes
        removed; an empty string if the reply carried no text; or a string
        starting with "API_ERROR: " if encoding the image or the request
        failed. Failures are returned as data so that one bad item does not
        abort the other concurrent tasks.
    """
    # Hold a semaphore slot for the whole request so concurrency stays bounded.
    async with semaphore:
        try:
            # Encode inside the try: an image that cannot be serialized is then
            # reported for this item only instead of raising out of gather().
            base64_image = pil_to_base64(image)

            response = await session_client.chat.completions.create(
                model="gemini-2.5-flash", # As specified in the task
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": PROMPT_TEMPLATE},
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                            },
                        ],
                    }
                ],
                # NOTE(review): a class name is short, but if the model spends
                # output tokens on hidden reasoning, 50 may leave no room for
                # the answer - confirm against the gateway's behavior.
                max_tokens=50,
                temperature=0.0, # Set to 0 for deterministic, most likely output
            )

            # ``content`` is optional in the API response; treat a missing
            # reply as empty text so it is scored as an unrecognized answer
            # rather than being misreported as an API failure.
            content = response.choices[0].message.content
            return (content or "").strip().strip('"')
        except Exception as e:
            # Task boundary: report the failure as a tagged string for the caller to count.
            return f"API_ERROR: {str(e)}"





# --- 3. Run Asynchronous Classification ---
async def run_classification(concurrency_limit: int = 50):
    """Classify every image in ``images_to_test`` concurrently.

    Args:
        concurrency_limit: Maximum number of API requests allowed in flight
            at the same time. Defaults to 50.

    Returns:
        The list of values returned by ``classify_image_with_gemini``, one per
        image, as gathered by ``tqdm_asyncio.gather``.

    Raises:
        ValueError: If ``concurrency_limit`` is smaller than 1 (a semaphore
            with no slots would make every task wait forever).
    """
    if concurrency_limit < 1:
        raise ValueError(f"concurrency_limit must be >= 1, got {concurrency_limit}")

    print("\nStarting classification with Gemini 2.5 Flash...")

    # One semaphore shared by all tasks caps the number of simultaneous requests.
    semaphore = asyncio.Semaphore(concurrency_limit)
    tasks = [classify_image_with_gemini(img, client, semaphore) for img in images_to_test]

    # tqdm_asyncio shows a progress bar for our async tasks
    predictions = await tqdm_asyncio.gather(*tasks)
    return predictions

# Run the main async function
# Note: This will make 1010 API calls and may take several minutes and incur costs.
gemini_predictions_str = await run_classification()
print("All API calls completed.")

# --- 4. Calculate Accuracy and Compare ---
# Convert predicted string labels to integer IDs
# Create a mapping from class name string to integer ID
name_to_id_map = {name: i for i, name in enumerate(class_names)}

predicted_labels = []
api_errors = 0
invalid_responses = 0
for pred_str in gemini_predictions_str:
    if "API_ERROR" in pred_str:
        api_errors += 1
        predicted_labels.append(-1) # Mark as incorrect
        continue

    # Clean the model's output string
    # 1. Convert to lowercase
    # 2. Replace spaces and hyphens with underscores
    # 3. Remove common punctuation and extra words
    cleaned_str = pred_str.lower().replace(' ', '_').replace('-', '_')
    
    # Find the best matching class name in the cleaned string
    found_match = False
    for class_name in class_names:
        if class_name in cleaned_str:
            predicted_labels.append(name_to_id_map[class_name])
            found_match = True
            break # Stop after finding the first match
    
    if not found_match:
        invalid_responses += 1
        predicted_labels.append(-1) # Mark as incorrect if no class name is found
# --- MODIFICATION END ---

# Ensure lists are numpy arrays for metric calculation
true_labels_np = np.array(true_labels)
predicted_labels_np = np.array(predicted_labels)

# Calculate Top-1 Accuracy
gemini_accuracy = accuracy_score(true_labels_np, predicted_labels_np)

print("\n--- Subtask 5 Results ---")
print(f"Total images processed: {len(images_to_test)}")
print(f"Successful API calls: {len(images_to_test) - api_errors}")
print(f"API errors: {api_errors}")
print(f"Invalid/unrecognized responses: {invalid_responses}")
print(f"\nGemini 2.5 Flash Top-1 Accuracy: {gemini_accuracy:.4f}")

if 'top5_accuracy' in locals() or 'top5_accuracy' in globals():
    print(f"For comparison, SigLIP Zero-Shot Top-5 Accuracy (from Subtask 2): {top5_accuracy:.4f}")
else:
    print("\n(SigLIP Zero-Shot Top-5 Accuracy from Subtask 2 was not calculated in this session.)")
    
print("\nSubtask 5 Finished!")