In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torchvision import transforms
from PIL import Image
import os
import numpy as np

# Pre-trained ResNet-18 with its last child module (the classification layer)
# dropped, so a forward pass returns the pooled feature map instead of logits.
model = nn.Sequential(
    *list(resnet18(weights=ResNet18_Weights.DEFAULT).children())[:-1]
)
model.eval()  # evaluation (inference) mode for feature extraction

# Preprocessing applied to every image before it reaches the model:
# resize the shorter side to 256 px, take the central 224x224 crop,
# convert to a [0, 1] tensor, then normalise per channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def extract_features(image_path):
    """Extract a ResNet feature vector for a single image.

    Parameters
    ----------
    image_path : str | os.PathLike | file-like | PIL.Image.Image
        A path or binary file object that ``PIL.Image.open`` can read, or an
        already-loaded PIL image (for example a crop returned by an object
        detector). In-memory images are used directly, without re-opening.

    Returns
    -------
    numpy.ndarray | None
        1-D feature vector (one value per channel of the model's pooled
        output), or ``None`` if the image could not be processed. Failures are
        reported with ``print`` so a batch run can keep going.
    """
    try:
        if isinstance(image_path, Image.Image):
            img = image_path.convert('RGB')
        else:
            # ``with`` closes the underlying file; convert() returns an
            # independent, fully-loaded copy that stays valid afterwards.
            with Image.open(image_path) as src:
                img = src.convert('RGB')

        batch_t = transform(img).unsqueeze(0)  # add the batch dimension

        with torch.no_grad():
            features = model(batch_t)

        # flatten() always yields a 1-D vector (squeeze() would collapse a
        # single-channel output to a 0-d array).
        return features.flatten().numpy()
    except Exception as e:
        print(f"Could not process image {image_path}: {e}")
        return None

def process_image_directory(directory, extensions=('.png', '.jpg', '.jpeg')):
    """Extract features for every image file directly inside ``directory``.

    Nothing is written to disk; the caller decides whether and where to save.

    Parameters
    ----------
    directory : str | os.PathLike
        Folder to scan (not recursive).
    extensions : tuple[str, ...]
        File suffixes treated as images. Matching is case-insensitive, so
        ``photo.JPG`` is picked up by the default ``'.jpg'``.

    Returns
    -------
    (numpy.ndarray, list[str])
        Feature vectors stacked one row per successfully processed image, and
        the matching file paths in the same order. Files for which
        ``extract_features`` returns ``None`` are left out of both.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    all_features = []
    image_paths = []

    # os.listdir order is arbitrary; sorting makes the row order (and thus the
    # FAISS ids built from it) reproducible across runs and machines.
    for filename in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, filename)
        if not filename.lower().endswith(suffixes) or not os.path.isfile(file_path):
            continue
        features = extract_features(file_path)
        if features is not None:
            all_features.append(features)
            image_paths.append(file_path)

    return np.array(all_features), image_paths

# --- Example Usage ---
# Expects a folder named 'product_images' holding the product photos to index.
image_directory = 'product_images'
if not os.path.exists(image_directory):
    print(f"Please create a folder '{image_directory}' and add your images.")
else:
    features, paths = process_image_directory(image_directory)
    print(f"Extracted features for {len(features)} images.")
    # Persist the vectors for the FAISS indexing step ...
    np.save('product_features.npy', features)
    # ... and the matching paths, one per line in the same order as the
    # vector rows, so search results can be mapped back to products.
    with open('product_paths.txt', 'w') as f:
        f.writelines(f"{path}\n" for path in paths)

In [None]:
!pip install faiss-cpu

In [None]:
import faiss
import numpy as np

# Load the features you saved from the previous step
features = np.load('product_features.npy')
# Load the corresponding product paths/IDs; line i of the file belongs to row i of `features`
with open('product_paths.txt', 'r') as f:
    paths = [line.strip() for line in f]

# FAISS expects a C-contiguous float32 matrix of shape (n_vectors, d)
features = np.ascontiguousarray(features, dtype='float32')

# Fail early with a readable message rather than an IndexError on `shape[1]`
# (an empty image folder produces a 1-D, zero-length array).
if features.ndim != 2 or features.shape[0] == 0:
    raise ValueError(
        f"Expected a non-empty 2-D feature matrix, got shape {features.shape}. "
        "Re-run the feature-extraction cell with at least one image."
    )
if len(paths) != features.shape[0]:
    raise ValueError(
        f"{features.shape[0]} feature rows but {len(paths)} saved paths; "
        "re-run the feature-extraction cell so both files are written together."
    )

# Get the dimension of the feature vectors
d = features.shape[1]

# Build the FAISS index
# faiss.IndexFlatL2 is an exact (brute-force) index; the distances it reports
# are squared L2 (Euclidean) distances.
index = faiss.IndexFlatL2(d)
print(f"FAISS index created with dimension {d}.")

# Add the feature vectors to the index
index.add(features)
print(f"Added {index.ntotal} vectors to the index.")

# --- Example Query ---
# A real query vector would come from the same ResNet model. Here the first
# stored vector is reused, so the best match should be that product itself
# (distance 0, unless exact duplicates exist).
query_vector = features[0:1]  # slice keeps it 2-D: (1, d)

# Ask for up to 5 neighbours. When k exceeds the number of stored vectors,
# FAISS pads the result with id -1, and paths[-1] would then silently name the
# wrong product, so cap k at the index size.
k = min(5, index.ntotal)
distances, indices = index.search(query_vector, k)

print("\n--- FAISS Search Results ---")
for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]), start=1):
    if idx < 0:  # defensive: -1 means "no result" for this slot
        continue
    print(f"Rank {rank}: Product ID {idx}, Distance: {dist:.4f}")
    print(f"   Original Path: {paths[idx]}")

# Save the index to a file for later use (e.g., in your Lambda function)
faiss.write_index(index, "product_index.faiss")
print("\nFAISS index saved to 'product_index.faiss'")

In [None]:
!pip install ultralytics


In [None]:
from ultralytics import YOLO
from PIL import Image

# Loaded YOLO models keyed by weights file, so repeated calls skip the reload.
_yolo_models = {}


def _load_yolo(weights):
    """Return the YOLO model for ``weights``, loading it only on first use."""
    if weights not in _yolo_models:
        _yolo_models[weights] = YOLO(weights)
    return _yolo_models[weights]


def get_product_crops(image_path, weights='yolov8n.pt'):
    """Detect objects in an image and return one cropped image per detection.

    Parameters
    ----------
    image_path : str | os.PathLike | file-like
        Anything ``PIL.Image.open`` accepts: a path or a binary stream such
        as ``io.BytesIO``.
    weights : str
        YOLO weights to use (default ``'yolov8n.pt'``). Other sizes
        (s, m, l, x) can be chosen instead. The model is loaded once per
        weights file and reused by later calls.

    Returns
    -------
    list[PIL.Image.Image]
        RGB crops in the order the detector reported the boxes; an empty
        list when nothing is detected.
    """
    # Decode the image exactly once. The ``with`` closes the file, and
    # convert() returns an independent in-memory copy used below.
    with Image.open(image_path) as src:
        img = src.convert('RGB')

    model = _load_yolo(weights)
    # Predict on the already-decoded image so the source is read only once.
    results = model(img)

    crops = []
    for result in results:
        # One row per detection: x1, y1, x2, y2
        boxes = result.boxes.xyxy.cpu().numpy()
        for x1, y1, x2, y2 in boxes.astype(int):
            crops.append(img.crop((x1, y1, x2, y2)))

    return crops

# --- Example Usage ---
# Assuming you have an image file named 'test_product.jpg'
image_to_process = 'test_product.jpg'

# Guard so "Restart & Run All" does not stop here when the sample image is absent.
if not os.path.exists(image_to_process):
    print(f"Please add an image named '{image_to_process}' to run this example.")
else:
    cropped_images = get_product_crops(image_to_process)

    if cropped_images:
        print(f"Found {len(cropped_images)} objects. The first one will be used for feature extraction.")
        # Next, pass the first crop to the ResNet feature extractor. The crop is
        # an in-memory PIL Image, not a file path, so the extractor must accept
        # images directly (or the crop must be saved to disk first).
        # first_crop_features = extract_features(cropped_images[0])
        # ... then use FAISS to search with these features
    else:
        print("No objects detected.")

In [None]:
import json
import boto3
import faiss
import numpy as np
import io
from PIL import Image
# Import your ResNet and YOLO models and functions
# from your_local_file import extract_features, get_product_crops

# Created once at module load and reused by every invocation handled by this
# Lambda execution environment.
s3_client = boto3.client('s3')

# IMPORTANT: Module-level objects (like `s3_client` above) are created once on a
# cold start and reused by later warm invocations. Load your pre-trained models
# and FAISS index the same way, so that cost is not paid on every request.
# 'product_index.faiss' (the file written by the FAISS indexing step) and
# 'product_data.json' should be stored in your S3 bucket and downloaded to the
# Lambda's temporary storage (/tmp) before they are read.
# A flat FAISS index keeps every vector in memory, so for very large catalogues
# consider sharding it or using a more memory-efficient index type.

# --- Mock-up of a Lambda function ---
_FAISS_INDEX_PATH = '/tmp/product_index.faiss'
_faiss_index = None  # filled on first use, then reused by warm invocations


def _get_faiss_index():
    """Return the FAISS index, reading it from /tmp only on the first call."""
    global _faiss_index
    if _faiss_index is None:
        _faiss_index = faiss.read_index(_FAISS_INDEX_PATH)
    return _faiss_index


def lambda_handler(event, context):
    """Process an uploaded image from an S3 event and return similar products.

    Parameters
    ----------
    event : dict
        S3 event notification; only ``event['Records'][0]`` is processed.
    context : object
        Lambda context object (unused).

    Returns
    -------
    dict
        ``statusCode`` 200 with a JSON body holding ``recommendations`` (FAISS
        row index and distance per match), or 500 with a JSON ``error``
        message if anything fails.
    """
    from urllib.parse import unquote_plus

    try:
        # Get the uploaded image details from the S3 event trigger
        s3_record = event['Records'][0]['s3']
        bucket_name = s3_record['bucket']['name']
        # S3 event notifications URL-encode object keys (a space arrives as '+'),
        # so decode before calling get_object or such keys are not found.
        file_key = unquote_plus(s3_record['object']['key'])

        print(f"Processing new image: s3://{bucket_name}/{file_key}")

        # Download the image from S3
        response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
        image_bytes = response['Body'].read()
        image_stream = io.BytesIO(image_bytes)  # input for the (mocked) detection step

        # Step 1: Object Detection (YOLOv8)
        # Assuming get_product_crops is a function you've deployed
        # cropped_image = get_product_crops(image_stream)[0]

        # Step 2: Feature Extraction (ResNet)
        # Assuming extract_features is a function you've deployed
        # query_vector = extract_features(cropped_image)

        # Step 3: Load FAISS index (cached across warm invocations) and search
        faiss_index = _get_faiss_index()
        # Mocking a vector for demonstration; sized from the index so it always matches.
        query_vector = np.random.rand(faiss_index.d).astype('float32').reshape(1, -1)

        # FAISS pads results with id -1 when k exceeds the stored vector count,
        # and rejects k == 0, so cap k and skip the search for an empty index.
        k = min(5, faiss_index.ntotal)

        # Step 4: Retrieve and return product information
        # Assuming you have a metadata file mapping indices to product data
        # with open('/tmp/product_data.json', 'r') as f:
        #     all_product_data = json.load(f)

        recommendations = []
        if k > 0:
            distances, indices = faiss_index.search(query_vector, k)
            for idx, dist in zip(indices[0], distances[0]):
                if idx < 0:  # -1 marks "no result" for this slot
                    continue
                # product_info = all_product_data[str(idx)]
                # recommendations.append(product_info)
                recommendations.append({"index": int(idx), "distance": float(dist)})

        print("Recommendations found.")

        return {
            'statusCode': 200,
            'body': json.dumps({
                'message': 'Image processed successfully',
                'recommendations': recommendations
            })
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        return {
            'statusCode': 500,
            'body': json.dumps({'error': str(e)})
        }