In [None]:
import boto3
import pandas as pd
import base64
import json
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional
import time
from tqdm import tqdm

class ImageEmbeddingProcessor:
    """Create image embeddings for product images and store them in DynamoDB.

    For every row of a product DataFrame the processor rewrites the image URL
    to a fixed resolution, downloads the image, requests an embedding from the
    Bedrock model ``amazon.titan-embed-image-v1`` and writes the result to the
    ``reverse_image_search`` DynamoDB table.

    Downloads and Bedrock calls run on a thread pool; DynamoDB writes are
    issued from the thread that calls :meth:`process_batch`.

    Error handling is best-effort: a failure on one item is reported with
    ``print`` and that item is skipped, so one bad image never aborts a run.
    """

    def __init__(self, region: str = "us-east-1", max_workers: int = 10, batch_size: int = 100):
        """Set up the AWS clients.

        Args:
            region: AWS region used for both Bedrock and DynamoDB.
            max_workers: Size of the thread pool used for each batch.
            batch_size: Number of DataFrame rows handled per batch.
        """
        self.region = region
        self.max_workers = max_workers
        self.batch_size = batch_size

        # Initialize AWS clients
        self.bedrock_client = boto3.client(
            "bedrock-runtime",
            region_name=region,
            endpoint_url=f"https://bedrock-runtime.{region}.amazonaws.com",
        )
        self.dynamodb = boto3.resource('dynamodb', region_name=region)
        self.table = self.dynamodb.Table('reverse_image_search')

        # Default writer so _store_embedding_batch is usable on its own.
        # process_batch swaps in a per-batch writer that it flushes on exit.
        self.batch_writer = self.table.batch_writer()

    def process_batch(self, batch_df: pd.DataFrame) -> None:
        """Embed and store every image of ``batch_df``.

        Rows are processed concurrently; each finished embedding is handed to
        the DynamoDB batch writer as soon as its future completes. The writer
        is used as a context manager so that items still sitting in its buffer
        are sent when the batch ends instead of being left unwritten.

        Args:
            batch_df: Rows to process; each row needs ``image_one`` and
                ``product_id``.
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(self._process_single_image, row)
                for _, row in batch_df.iterrows()
            ]

            try:
                with self.table.batch_writer() as writer:
                    self.batch_writer = writer

                    # Process completed futures with progress bar
                    for future in tqdm(as_completed(futures), total=len(futures), desc="Processing batch"):
                        try:
                            result = future.result()
                            if result:
                                self._store_embedding_batch(*result)
                        except Exception as e:
                            print(f"Error processing image: {e}")
            except Exception as e:
                # Raised when the final flush on leaving the writer fails;
                # reported like every other failure in this class.
                print(f"Error writing batch to DynamoDB: {e}")

    def _process_single_image(self, row) -> Optional[tuple]:
        """Download one image and compute its embedding.

        Args:
            row: A DataFrame row providing ``image_one`` (image URL) and
                ``product_id``.

        Returns:
            ``(product_id, embedding)`` on success, ``None`` when the download
            or the embedding request failed.
        """
        try:
            modified_url = self._modify_image_url(row['image_one'])
            image_data = self._download_image(modified_url)

            if image_data is None:
                return None

            base64_encoded_image = base64.b64encode(image_data).decode('utf-8')
            embedding = self._create_image_embedding(base64_encoded_image)

            if embedding:
                return (row['product_id'], embedding)

        except Exception as e:
            print(f"Error processing product {row['product_id']}: {e}")
        return None

    def _store_embedding_batch(self, product_id: int, embedding: List[float]) -> None:
        """Queue one embedding on the current batch writer.

        The embedding is stored as a JSON string under ``image_embedding``.
        The write may be buffered; it is sent when the writer's buffer fills
        or when the writer is flushed.

        Args:
            product_id: Key of the item.
            embedding: Embedding vector returned by the model.
        """
        try:
            # Values read from pandas may be NumPy scalars, which boto3's
            # DynamoDB serializer does not accept; .item() gives the native
            # Python equivalent.
            if hasattr(product_id, 'item'):
                product_id = product_id.item()

            self.batch_writer.put_item(
                Item={
                    'product_id': product_id,
                    'image_embedding': json.dumps(embedding)
                }
            )
        except Exception as e:
            print(f"Error storing embedding for product {product_id}: {e}")

    @staticmethod
    def _modify_image_url(original_url: str, resolution: str = '448x448') -> str:
        """Return ``original_url`` with its resolution segment set to ``resolution``.

        The first path segment of the form ``<digits>x<digits>`` is replaced.
        When the URL has no such segment, ``resolution`` is inserted right
        after ``cloudfront.net/``; a URL without that host text is returned
        unchanged.

        Args:
            original_url: Image URL to rewrite.
            resolution: Replacement segment, e.g. ``'448x448'``.
        """
        def is_resolution(segment: str) -> bool:
            # Require WIDTHxHEIGHT with numeric sides. Matching any segment
            # that merely contains an "x" would also hit host names or file
            # names such as "example.com" or "box.jpg".
            width, separator, height = segment.partition('x')
            return bool(separator) and width.isdigit() and height.isdigit()

        parts = original_url.split('/')
        for index, part in enumerate(parts):
            if is_resolution(part):
                parts[index] = resolution
                return '/'.join(parts)

        return original_url.replace('cloudfront.net/', f'cloudfront.net/{resolution}/')

    @staticmethod
    def _download_image(url: str) -> Optional[bytes]:
        """Fetch ``url`` and return the response body.

        Returns:
            The raw bytes, or ``None`` when the request fails, times out
            (10 seconds) or answers with an HTTP error status.
        """
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return response.content
        except requests.RequestException as e:
            print(f"Error downloading image from {url}: {e}")
            return None

    def _create_image_embedding(self, image: str) -> Optional[List[float]]:
        """Request an embedding for a base64-encoded image from Bedrock.

        Args:
            image: Base64 text of the image bytes.

        Returns:
            The value of the response's ``embedding`` field, or ``None`` when
            the call fails or the response carries a ``message`` field.
        """
        try:
            response = self.bedrock_client.invoke_model(
                body=json.dumps({"inputImage": image}),
                modelId="amazon.titan-embed-image-v1",
                accept="application/json",
                contentType="application/json"
            )

            result = json.loads(response.get("body").read())
            if "message" in result:
                print(f"Error creating embeddings: {result['message']}")
                return None

            return result.get("embedding")

        except Exception as e:
            print(f"Error creating embedding: {e}")
            return None

    def process_dataset(self, df: pd.DataFrame, start_product_id: Optional[int] = None) -> None:
        """
        Process dataset in batches, optionally starting from a specific product_id

        Rows are handled in ascending ``product_id`` order. The id printed
        after each batch is that batch's largest id, so passing it back as
        ``start_product_id`` resumes exactly after the rows already handled.

        Args:
            df (pd.DataFrame): Input DataFrame
            start_product_id (int, optional): Product ID to start processing from
        """
        # Filter DataFrame if start_product_id is provided
        if start_product_id is not None:
            df = df[df['product_id'] > start_product_id].copy()
            print(f"Starting from product_id > {start_product_id}")
            print(f"Remaining items to process: {len(df)}")

        if len(df) == 0:
            print("No items to process")
            return

        # Without ascending order the per-batch maximum is not a safe resume
        # point: smaller ids could still be waiting in a later batch.
        df = df.sort_values('product_id', kind='stable')

        total_items = len(df)
        total_processed = 0
        start_time = time.monotonic()

        for i in range(0, total_items, self.batch_size):
            # .iloc makes the positional slicing explicit
            batch_df = df.iloc[i:i + self.batch_size]
            self.process_batch(batch_df)
            total_processed += len(batch_df)

            # Print progress; the floor keeps the rate finite when the clock
            # has not advanced yet.
            elapsed_time = max(time.monotonic() - start_time, 1e-9)
            rate = total_processed / elapsed_time
            print(f"\nProgress: {total_processed}/{total_items} images")
            print(f"Processing rate: {rate:.2f} images/second")
            print(f"Estimated time remaining: {(total_items - total_processed) / rate / 60:.2f} minutes")

            # Optional: Save checkpoint
            last_processed_id = batch_df['product_id'].max()
            print(f"Last processed product_id: {last_processed_id}")

# Usage example: resume embedding generation after a known checkpoint
processor = ImageEmbeddingProcessor(max_workers=10, batch_size=100)

# Load the dataset, keeping only the first row seen for each product_id
df = pd.read_csv('sample_dataset.csv').drop_duplicates(subset='product_id', keep='first')

# Only rows whose product_id is greater than this value are processed
last_processed_id = 500  # Replace with your last processed product_id
processor.process_dataset(df, start_product_id=last_processed_id)