# In [None]:
import os
import json
import base64
import asyncio
import aiohttp
import time
from anthropic import Anthropic
from PIL import Image
import io
import pandas as pd
import numpy as np
import random
from sklearn.utils import shuffle
import nest_asyncio
from tqdm import tqdm

# Patch asyncio so an event loop can be re-entered. This lets the
# `asyncio.run(main())` call at the bottom of the file work even when an event
# loop is already running (presumably a Jupyter kernel, given the "In [...]"
# cell marker at the top of the file — confirm).
nest_asyncio.apply()

def load_image(image_path: str, quality: int = 95) -> "str | None":
    """
    Load an image from disk, re-encode it as JPEG, and return it base64-encoded.

    Images in any mode other than RGB are converted to RGB first, because the
    JPEG encoder cannot store e.g. an alpha channel or a palette.

    Args:
        image_path: Path of the image file to read.
        quality: JPEG quality passed to Pillow when re-encoding. The default of
            95 matches the previously hard-coded value.

    Returns:
        The base64-encoded JPEG bytes as a UTF-8 string, or ``None`` if the image
        could not be opened or encoded. Callers must check for ``None``.
    """
    try:
        with Image.open(image_path) as img:
            # convert() returns a new in-memory image. The file handle opened
            # above is still released when the `with` block exits.
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            buffer = io.BytesIO()
            rgb.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        # Best-effort: report the failure and signal it with None, instead of
        # aborting the whole evaluation run on one unreadable file.
        print(f"Error processing image {image_path}: {str(e)}")
        return None

class RateLimiter:
    """
    Sliding-window rate limiter for coroutines that share one event loop.

    At most ``max_requests`` calls to :meth:`wait` return within any
    ``time_window``-second window. Further callers sleep until the oldest
    recorded request leaves the window.

    Args:
        max_requests: Maximum number of requests per window; must be >= 1.
        time_window: Length of the sliding window, in seconds.

    Raises:
        ValueError: If ``max_requests`` < 1, which could never admit a request.
    """

    def __init__(self, max_requests, time_window):
        if max_requests < 1:
            raise ValueError(f"max_requests must be >= 1, got {max_requests}")
        self.max_requests = max_requests
        self.time_window = time_window
        # time.monotonic() timestamps of the requests inside the current window.
        # They are only ever appended, so the list stays in ascending order.
        self.request_times = []

    async def wait(self):
        """Sleep until a request slot is free, then claim it."""
        while True:
            # monotonic() is unaffected by wall-clock adjustments (NTP, manual
            # changes), which could otherwise stall the limiter or let bursts through.
            now = time.monotonic()
            cutoff = now - self.time_window
            self.request_times = [t for t in self.request_times if t > cutoff]
            if len(self.request_times) < self.max_requests:
                # There is no await between the check and the append, so two
                # coroutines on the same loop cannot both claim the last slot.
                self.request_times.append(now)
                return
            # Sleep until the oldest request expires rather than polling.
            # The small floor guards against float rounding near the boundary.
            delay = self.request_times[0] + self.time_window - now
            await asyncio.sleep(max(delay, 0.001))

class GPTAPI:
    """
    Minimal async client for the OpenAI Chat Completions endpoint.

    Sends a single user message containing the prompt, any few-shot example
    parts, and the query image, and returns the text of the first choice.
    """

    def __init__(self, api_key, model="gpt-4o-2024-05-13", max_tokens=300, temperature=0.5):
        """
        Args:
            api_key: OpenAI API key, sent as a Bearer token.
            model: Chat Completions model name.
            max_tokens: Completion-token cap per request. The previously
                hard-coded 120000 exceeds the 4096 output-token limit of
                gpt-4o-2024-05-13, so the API rejected every request. The
                expected reply is a short JSON object, so a few hundred tokens
                is ample.
            temperature: Sampling temperature.
        """
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.url = "https://api.openai.com/v1/chat/completions"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        # At most 5 requests per rolling 30-second window.
        self.rate_limiter = RateLimiter(max_requests=5, time_window=30)

    def _build_payload(self, inputs: dict) -> dict:
        """
        Build the Chat Completions request body.

        Args:
            inputs: Dict with three keys.
                - 'prompt': the instruction text.
                - 'examples': a list of already-formatted content parts, placed
                  between the prompt and the query image.
                - 'image': the query image, base64-encoded and sent as a JPEG
                  data URL.

        Returns:
            The JSON-serializable payload dict.
        """
        return {
            "model": self.model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": inputs['prompt']},
                        *inputs['examples'],
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{inputs['image']}",
                                "detail": "high"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }

    async def get_image_information(self, inputs: dict) -> str:
        """
        Send one classification request and return the model's reply text.

        Args:
            inputs: See :meth:`_build_payload`.

        Returns:
            ``choices[0].message.content`` from the response.

        Raises:
            Exception: If the response has no non-empty ``choices`` list (for
                example, an API error payload). The message includes the full
                response.
        """
        await self.rate_limiter.wait()
        payload = self._build_payload(inputs)
        # A fresh session per request keeps lifetime management trivial. The
        # rate limiter caps the request volume, so connection reuse matters little.
        async with aiohttp.ClientSession() as session:
            async with session.post(self.url, headers=self.headers, json=payload) as response:
                result = await response.json()
                if "choices" in result and result["choices"]:
                    return result["choices"][0]['message']['content']
                else:
                    raise Exception(f"Unexpected API response format: {result}")

class ClaudeAPI:
    """
    Async wrapper around the Anthropic Messages API for image classification.

    Sends a single user message containing the prompt, any few-shot example
    parts, and the query image, and returns the text of the first content block.
    """

    def __init__(self, api_key, model="claude-3-sonnet-20240229", max_tokens=300, temperature=0.5):
        """
        Args:
            api_key: Anthropic API key.
            model: Anthropic model name.
            max_tokens: Output-token cap per request. The previously hard-coded
                120000 exceeds the 4096 output-token limit of the Claude 3
                models, so the API rejected every request. The expected reply is
                a short JSON object.
            temperature: Sampling temperature.
        """
        self.client = Anthropic(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.rate_limiter = RateLimiter(max_requests=5, time_window=60)  # Adjust these values as needed

    @staticmethod
    def _to_claude_part(ex):
        """
        Convert one example content part to the Anthropic format.

        OpenAI-style ``image_url`` parts (a base64 JPEG data URL) and Anthropic
        ``image`` parts both become base64 ``image`` blocks. Any other part
        (e.g. ``text``) is returned unchanged.
        """
        if ex['type'] == 'image_url':
            # Strip the "data:image/jpeg;base64," prefix of the data URL.
            data = ex['image_url']['url'].split(',')[1]
        elif ex['type'] == 'image':
            data = ex['source']['data']
        else:
            return ex
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": data
            }
        }

    def _build_messages(self, inputs: dict) -> list:
        """
        Build the ``messages`` list for one request.

        Args:
            inputs: Dict with three keys.
                - 'prompt': the instruction text.
                - 'examples': a list of content parts in either OpenAI or
                  Anthropic form.
                - 'image': the query image as base64-encoded JPEG.
        """
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": inputs['prompt']},
                    *[self._to_claude_part(ex) for ex in inputs['examples']],
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": inputs['image']
                        }
                    }
                ]
            }
        ]

    async def get_image_information(self, inputs: dict) -> str:
        """
        Send one classification request and return the model's reply text.

        Args:
            inputs: See :meth:`_build_messages`.

        Returns:
            The text of the first content block of the response.
        """
        await self.rate_limiter.wait()
        messages = self._build_messages(inputs)
        # self.client is the synchronous Anthropic client. Calling it directly
        # inside a coroutine would block the event loop, so every concurrently
        # scheduled image would be processed one after another. Run it in the
        # loop's default thread-pool executor instead.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                messages=messages,
            ),
        )
        return response.content[0].text
# Run configuration. Each value can be overridden through an environment
# variable, so the script can run on another machine without editing the
# source. The defaults reproduce the original hard-coded settings.

# Root of the dataset split. Each sub-directory is treated as one class.
base_directory = os.getenv(
    'AGEVAL_BASE_DIR',
    '/Users/muhammadarbabarshad/Downloads/AgEval-datasets/severity-based-rice-disease/train',
)
# 'gpt' for the OpenAI API, 'claude' for the Anthropic API.
model_type = os.getenv('AGEVAL_MODEL_TYPE', 'gpt')
# Total number of images to evaluate. It is split evenly across
# expected_classes, rounding down (10 samples over 9 classes -> 1 per class).
total_samples_to_check = int(os.getenv('AGEVAL_TOTAL_SAMPLES', '10'))

# Class labels the model must choose between. They are also interpolated into the prompt.
expected_classes = ['Healthy', 'Mild Bacterial Blight', 'Mild Blast', 'Mild Brownspot', 'Mild Tungro', 'Severe Bacterial Blight', 'Severe Blast', 'Severe Brownspot', 'Severe Tungro']

def load_and_prepare_data(base_directory, total_samples_to_check):
    """
    Collect image paths per class and draw a balanced, shuffled sample.

    Every sub-directory of ``base_directory`` is treated as a class named after
    the directory. Files ending in .jpg/.jpeg/.png (case-insensitive) are
    collected from it; other entries, such as stray files like .DS_Store, are
    ignored.

    ``total_samples_to_check // len(expected_classes)`` images are then sampled
    from each class, or all of them if the class has fewer. The combined sample
    is shuffled with a fixed seed.

    Args:
        base_directory: Directory containing one sub-directory per class.
        total_samples_to_check: Total sample budget to spread across the classes.

    Returns:
        DataFrame with integer columns 0 (file path) and 1 (class label) and an
        index of 0..n-1. If no images were found, it is empty but has the same columns.
    """
    num_instances_per_class = max(0, total_samples_to_check // len(expected_classes))
    file_paths = []
    labels = []

    # Sort the listings: os.listdir order depends on the filesystem, so without
    # sorting, the fixed random_state below would not select the same images on
    # every machine.
    for subdir in sorted(os.listdir(base_directory)):
        subdir_path = os.path.join(base_directory, subdir)
        # Only directories are classes. Calling os.listdir on a file such as
        # .DS_Store would raise NotADirectoryError.
        if not os.path.isdir(subdir_path):
            continue
        for filename in sorted(os.listdir(subdir_path)):
            if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                file_paths.append(os.path.join(subdir_path, filename))
                labels.append(subdir)

    data = pd.DataFrame({0: file_paths, 1: labels})
    if data.empty:
        return pd.DataFrame(columns=[0, 1])

    per_class = []
    for cls in data[1].unique():
        class_data = data[data[1] == cls]
        # DataFrame.sample raises if n exceeds the number of rows, so clamp it.
        n = min(num_instances_per_class, len(class_data))
        per_class.append(class_data.sample(n=n, random_state=42))
    sampled_data = pd.concat(per_class, ignore_index=True)

    return shuffle(sampled_data, random_state=42).reset_index(drop=True)

# Instruction text sent with every query image. expected_classes is
# interpolated twice by the f-string, and the doubled braces ({{ }}) render as
# literal JSON braces. The model is asked to reply with bare JSON of the form
# {"prediction": "<class name>"}.
# NOTE(review): the first sentence has a grammar slip ("...for your prediction
# It should be one of the :"). It is deliberately left untouched, because
# changing the prompt text would make results incomparable with earlier runs.
# Fix it only as a conscious experiment change.
vision_prompt = f"""
Given the image, identify the class. Use the following list of possible classes for your prediction It should be one of the : {expected_classes}. Be attentive to subtle details as some classes may appear similar. Provide your answer in the following JSON format:
{{"prediction": "class_name"}}
Replace "class_name" with the appropriate class from the list above based on your analysis of the image.
The labels should be entered exactly as they are in the list above i.e., {expected_classes}.
The response should start with {{ and contain only a JSON object (as specified above) and no other text.
"""

class ProgressBar:
    """Small wrapper around a tqdm bar labelled "Processing images"."""

    def __init__(self, total):
        """Create the bar. ``total`` is the number of update() calls expected."""
        bar = tqdm(desc="Processing images", total=total)
        self.pbar = bar

    def update(self):
        """Advance the bar by a single step."""
        step = 1
        self.pbar.update(step)

    def close(self):
        """Close the underlying tqdm bar."""
        bar = self.pbar
        bar.close()

def _load_json_reply(response_text):
    """
    Decode the JSON value contained in a model reply.

    The prompt asks for a bare JSON object, but chat models often wrap it in a
    Markdown code fence (```json ... ```) or add text around it. The reply is
    first decoded as-is. If that fails, the span from the first '{' to the last
    '}' is decoded instead.

    Raises:
        json.JSONDecodeError: If neither attempt yields valid JSON.
    """
    import re

    text = response_text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match is None:
            raise
        return json.loads(match.group(0))


async def process_image(api, i, number_of_shots, all_data_results, all_data, progress_bar):
    """
    Classify row ``i`` of ``all_data`` and store the result in ``all_data_results``.

    ``number_of_shots`` other rows are chosen with the (unseeded) ``random``
    module and sent as labelled few-shot examples before the query image. The
    parsed class name, or an 'Error: ...' string, is written to row ``i`` of
    column ``"# of Shots {number_of_shots}"``. The progress bar is advanced
    exactly once, whatever happens.

    Args:
        api: GPTAPI or ClaudeAPI instance. Its type selects the example format.
        i: Row index of the query image. Column 0 of ``all_data`` holds image
            paths and column 1 holds labels.
        number_of_shots: Number of labelled examples to include.
        all_data_results: DataFrame the prediction is written into.
        all_data: DataFrame of (path, label) rows.
        progress_bar: Object whose ``update()`` is called once per image.
    """
    column = f"# of Shots {number_of_shots}"
    try:
        image_path = all_data[0][i]
        image_base64 = load_image(image_path)
        if image_base64 is None:
            raise ValueError(f"Failed to load image: {image_path}")

        examples = []
        num_rows = len(all_data)
        # Exclude i so the query image never appears among its own examples.
        random_indices = random.sample([idx for idx in range(num_rows) if idx != i], number_of_shots)

        for j in random_indices:
            example_image_path = all_data[0][j]
            example_image_base64 = load_image(example_image_path)
            if example_image_base64 is None:
                # Skip an unreadable example entirely, including its label.
                continue
            if isinstance(api, GPTAPI):
                examples.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{example_image_base64}", "detail": "high"}})
            elif isinstance(api, ClaudeAPI):
                examples.append({
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": example_image_base64
                    }
                })
            # The expected answer for the example, in the same JSON form the
            # prompt asks the model to produce.
            examples.append({"type": "text", "text": f'{{"prediction": "{all_data.at[j, 1]}"}}'})

        prediction = await api.get_image_information({"image": image_base64, "examples": examples, "prompt": vision_prompt})

        try:
            parsed_prediction = _load_json_reply(prediction)['prediction']
        except json.JSONDecodeError:
            print(f"Error parsing JSON for image {image_path}. API response: {prediction}")
            parsed_prediction = 'Error: Invalid JSON'
        except (KeyError, TypeError):
            # KeyError: the object lacks "prediction". TypeError: the reply
            # decoded to a non-object, such as a list or a plain string.
            print(f"Missing 'prediction' key for image {image_path}. API response: {prediction}")
            parsed_prediction = 'Error: Missing prediction'

        all_data_results.at[i, column] = parsed_prediction

    except Exception as e:
        # Record the failure for this row and keep the rest of the run going.
        print(f"Error processing {all_data[0][i]}: {str(e)}")
        all_data_results.at[i, column] = f'Error: {str(e)}'
    finally:
        progress_bar.update()

async def process_images_for_shots(api, number_of_shots, all_data_results, all_data):
    """
    Run process_image concurrently for every row of ``all_data``.

    process_image writes each result into ``all_data_results`` itself. This
    coroutine only schedules one task per row, waits for all of them, and
    manages the shared progress bar.

    Args:
        api: GPTAPI or ClaudeAPI instance.
        number_of_shots: Number of few-shot examples per query.
        all_data_results: DataFrame that receives the predictions.
        all_data: DataFrame of (path, label) rows to classify.
    """
    progress_bar = ProgressBar(len(all_data))
    try:
        tasks = [
            asyncio.ensure_future(
                process_image(api, i, number_of_shots, all_data_results, all_data, progress_bar)
            )
            for i in range(len(all_data))
        ]
        await asyncio.gather(*tasks)
    finally:
        # Close the bar even if the run is cancelled (e.g. an interrupted
        # notebook cell), so a half-finished bar is not left dangling.
        progress_bar.close()

async def main():
    """
    Sample the dataset, query the configured model for each image, and save a CSV.

    The results are written to ``results_<model_type>/SBRD.csv``.

    Raises:
        ValueError: If ``model_type`` is unsupported, or if the API-key
            environment variable for the chosen model is unset or empty.
    """
    all_data = load_and_prepare_data(base_directory, total_samples_to_check)
    all_data_results = all_data.copy(deep=True)
    # Turn the integer column labels 0/1 into '0'/'1'. Presumably this keeps
    # the column labels uniformly strings alongside the "# of Shots N" columns
    # added later — confirm.
    all_data_results.columns = all_data_results.columns.map(str)

    if model_type == "gpt":
        api_cls, key_var = GPTAPI, "OPENAI_API_KEY"
    elif model_type == "claude":
        api_cls, key_var = ClaudeAPI, "ANTHROPIC_API_KEY"
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    api_key = os.getenv(key_var)
    if not api_key:
        # Fail fast. Otherwise every request would fail on its own and the CSV
        # would be filled with authentication errors.
        raise ValueError(f"Environment variable {key_var} is not set")
    api = api_cls(api_key=api_key)

    # Extend this list (e.g. [0, 1, 2, 4, 8]) to also run few-shot variants.
    for number_of_shots in [0]:
        await process_images_for_shots(api, number_of_shots, all_data_results, all_data)

    output_dir = f"results_{model_type}"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "SBRD.csv")
    all_data_results.to_csv(output_file)
    print(f"Results saved to {output_file}")

# Entry point. asyncio.run() cannot be called while another event loop is
# running in the same thread (as inside a notebook kernel). The
# nest_asyncio.apply() call near the top of this file patches asyncio so that
# this call works there too.
if __name__ == "__main__":
    asyncio.run(main())