In [7]:
from flask import Flask, request, jsonify, send_file
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from io import BytesIO

In [8]:
def load_embeddings(embedding_file):
    return np.load(embedding_file, allow_pickle=True).item()

In [9]:
class MyServer:
    def __init__(self, database_embeddings=None):
        self.app = Flask(__name__)
        self.database_embeddings = database_embeddings
        self.db_matrix = None
        self.model = None
        self.nbrs = None
        self.setup_routes()
        self.load_model_and_embeddings()

    def setup_routes(self):
        @self.app.route('/search', methods=['POST'])
        def search():
            if 'image' not in request.files:
                return jsonify({"error": "No image uploaded"}), 400
            
            file = request.files['image']
            if file.filename == '':
                return jsonify({"error": "No image selected for uploading"}), 400

            # Save the uploaded file temporarily
            img_path = './temp_upload.jpg'
            file.save(img_path)
            
            # Get the embedding of the uploaded image
            input_embedding = self.get_image_embedding(img_path)
            
            # Find the closest images
            closest_images = self.find_closest_images_approx(input_embedding, top_k=5)
            
            # Optionally delete the temporary file
            os.remove(img_path)

            # Create a list of image files to return
            image_files = []
            for img_path in closest_images:
                image_files.append(img_path)

            # Return the actual images in the response
            return jsonify({"similar_images": image_files})

        @self.app.route('/images/<path:filename>', methods=['GET'])
        def get_image(filename):
            # Serve the image file
            return send_file(filename, mimetype='image/jpeg')

    def load_model_and_embeddings(self):
        print("Converting to matrix...")

        # Convert the embeddings to a matrix for easier distance computation
        self.db_matrix = np.vstack(list(self.database_embeddings.values()))
        self.db_matrix /= np.linalg.norm(self.db_matrix, axis=1)[:, np.newaxis]

        print("Converted to matrix successfully.")

        # Use NearestNeighbors to find approximate nearest neighbors
        self.nbrs = NearestNeighbors(n_neighbors=10, algorithm='auto').fit(self.db_matrix)

        print("Built NearestNeighbors model successfully.")

        # Load the ResNet50 model
        self.model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

        print("Loaded ResNet50 model successfully.")

    def get_image_embedding(self, img_path):
        img = Image.open(img_path).resize((224, 224))
        x = np.array(img, dtype=np.float32)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)

        embedding = self.model.predict(x)
        return embedding

    def find_closest_images_approx(self, input_embedding, top_k=10):
        # Normalize the input embedding
        input_embedding = input_embedding.flatten()
        input_embedding /= np.linalg.norm(input_embedding)

        # Use NearestNeighbors to find approximate nearest neighbors
        distances, indices = self.nbrs.kneighbors([input_embedding])

        # Get the paths of the closest images
        closest_images = [list(self.database_embeddings.keys())[i] for i in indices[0]]

        return closest_images

In [10]:
if __name__ == '__main__':
    print("Loading embeddings...")
    database_embeddings = load_embeddings('database_embeddings_all.npy')
    print("Loaded embeddings successfully.")
    server = MyServer(database_embeddings)
    server.app.run(host='0.0.0.0', port=6014, debug=True)

Loading embeddings...
Loaded embeddings successfully.
Converting to matrix...
Converted to matrix successfully.
Built NearestNeighbors model successfully.
Loaded ResNet50 model successfully.
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:6014
 * Running on http://192.168.178.22:6014
Press CTRL+C to quit
 * Restarting with stat


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
