In [3]:
import psycopg2 as pg2
import cv2
import os
import numpy as np
import glob
from pgvector.psycopg2 import register_vector
import re
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT

In [11]:
class ImageMatcher:
    """Classify images by nearest-neighbour search over SIFT features stored in PostgreSQL.

    Every image is reduced to two fixed-length float vectors:

    * keypoints: the flattened attributes of its SIFT keypoints;
    * descriptors: its flattened SIFT descriptors.

    Training images are stored in the ``training_data`` table. A query image
    is encoded the same way and matched against the stored rows with
    pgvector's ``<->`` distance operator.
    """

    # Length of every stored vector. Longer feature lists are truncated and
    # shorter ones are zero-padded. pgvector's ``vector`` type is limited to
    # 16000 dimensions, so this must stay below that.
    VECTOR_DIM = 15900

    _INSERT_TRAINING_SQL = (
        "INSERT INTO training_data (image_path, class_type, file_name, keypoints, descriptors) "
        "VALUES (%s, %s, %s, %s, %s)"
    )

    # The query vectors are passed once as parameters and cast to ``vector``.
    # The inner SELECT computes both distances so the outer query can filter
    # and sort by alias without repeating the distance expressions.
    _SIMILARITY_SQL = """
    SELECT class_type, desc_distance, kp_distance
    FROM (
        SELECT
            training_data.class_type,
            training_data.descriptors <-> %(query_des)s::vector AS desc_distance,
            training_data.keypoints <-> %(query_kp)s::vector AS kp_distance
        FROM training_data
    ) AS distances
    WHERE desc_distance >= %(threshold)s
      AND kp_distance >= %(threshold)s
    ORDER BY desc_distance ASC, kp_distance ASC
    LIMIT 1;
    """

    def __init__(self, host='localhost', user='postgres', password='test',
                 database='vector_db', port=5432):
        """Connect to the database and open a shared cursor.

        All arguments default to the previously hard-coded settings, so
        ``ImageMatcher()`` connects exactly as before.
        """
        self._conn_params = {
            'host': host,
            'user': user,
            'password': password,
            'database': database,
            'port': port,
        }
        self.pgconn = self.connect_to_database()
        self.cursor = self.pgconn.cursor()
        # Alias for callers that use the ``curs`` attribute name.
        self.curs = self.cursor

    def _open_connection(self):
        """Open a new psycopg2 connection using the parameters given to __init__."""
        return pg2.connect(**self._conn_params)

    def connect_to_database(self):
        """Connect to PostgreSQL, ensure pgvector is installed, and register its adapters.

        If the first attempt raises ``psycopg2.InterfaceError``, one reconnect
        attempt is made.

        Returns:
            An open psycopg2 connection on which ``register_vector`` has been called.
        """
        try:
            pgconn = self._open_connection()
        except pg2.InterfaceError as e:
            print('{} - connection will be reset'.format(e))
            pgconn = self._open_connection()

        try:
            # register_vector needs the ``vector`` type to exist, so create the
            # extension first. Registering once per connection is sufficient.
            with pgconn.cursor() as curs:
                curs.execute('CREATE EXTENSION IF NOT EXISTS vector')
            register_vector(pgconn)
            pgconn.commit()
        except Exception:
            # Don't leak the connection if setup fails.
            pgconn.close()
            raise
        return pgconn

    def close(self):
        """Close the shared cursor and the connection; safe to call more than once."""
        if not self.cursor.closed:
            self.cursor.close()
        if not self.pgconn.closed:
            self.pgconn.close()

    def create_table(self):
        """Create the ``training_data`` table if it does not already exist."""
        self.cursor.execute("""
        CREATE TABLE IF NOT EXISTS training_data (
            id SERIAL PRIMARY KEY,
            image_path TEXT NOT NULL,
            class_type TEXT NOT NULL,
            file_name TEXT NOT NULL,
            keypoints VECTOR NOT NULL,
            descriptors VECTOR NOT NULL
        )
        """)
        self.pgconn.commit()

    def create_query_table(self):
        """Create the ``query_data`` table if it does not already exist.

        calculate_similarity no longer needs this table (query vectors are
        passed as parameters). The method is kept so existing callers still work.
        """
        self.cursor.execute("""
        CREATE TABLE IF NOT EXISTS query_data (
            id SERIAL PRIMARY KEY,
            query_des VECTOR NOT NULL,
            query_kp VECTOR NOT NULL
        )
        """)
        self.pgconn.commit()

    def imageResizeTrain(self, image, max_dim=1024):
        """Resize ``image`` so its longer side equals ``max_dim``, keeping the aspect ratio.

        Accepts both single-channel (H, W) and multi-channel (H, W, C) arrays.
        Images smaller than ``max_dim`` are upscaled.
        """
        height, width = image.shape[:2]
        aspect_ratio = width / height
        # cv2.resize expects (width, height). Clamp to 1 so extreme aspect
        # ratios never produce a zero-sized target.
        if aspect_ratio < 1:
            new_size = (max(1, int(max_dim * aspect_ratio)), max_dim)
        else:
            new_size = (max_dim, max(1, int(max_dim / aspect_ratio)))
        return cv2.resize(image, new_size)

    def imageResizeTest(self, image, max_dim=1024):
        """Same as imageResizeTrain; kept under this name for existing callers."""
        return self.imageResizeTrain(image, max_dim)

    def _load_grayscale(self, image_path):
        """Read ``image_path`` as grayscale and resize it.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise ValueError(f"Could not read image: {image_path}")
        return self.imageResizeTrain(image)

    def compute_sift(self, image):
        """Return SIFT ``(keypoints, descriptors)`` for the image file at path ``image``."""
        return cv2.SIFT_create().detectAndCompute(self._load_grayscale(image), None)

    def string_to_list(self, input_string):
        """Parse a comma-separated string of numbers into a list of floats."""
        return [float(x.strip()) for x in input_string.split(',')]

    def pad_vector(self, vector, target_dimension):
        """Return a copy of ``vector`` right-padded with zeros to ``target_dimension``.

        The input is not modified. Vectors already at least ``target_dimension``
        long are returned (as a copy) unchanged; this method never truncates.
        """
        padded = list(vector)
        if len(padded) < target_dimension:
            padded.extend([0.0] * (target_dimension - len(padded)))
        return padded

    def _to_fixed_length(self, values):
        """Truncate or zero-pad ``values`` to exactly VECTOR_DIM entries."""
        return self.pad_vector(values[:self.VECTOR_DIM], self.VECTOR_DIM)

    def _keypoints_to_vector(self, keypoints):
        """Flatten SIFT keypoints into a VECTOR_DIM-long list of floats.

        Each keypoint contributes seven values, in this order: x, y, size,
        angle, response, octave, class_id. An empty or ``None`` sequence
        yields an all-zero vector.
        """
        if keypoints is None:
            keypoints = ()
        values = []
        for point in keypoints:
            values.extend((point.pt[0], point.pt[1], point.size, point.angle,
                           point.response, point.octave, point.class_id))
            if len(values) >= self.VECTOR_DIM:
                break  # anything beyond VECTOR_DIM would be truncated anyway
        return self._to_fixed_length([float(v) for v in values])

    def _descriptors_to_vector(self, descriptors):
        """Flatten a SIFT descriptor array into a VECTOR_DIM-long list of floats.

        ``descriptors`` may be ``None``, which detectAndCompute can return when
        it finds no keypoints. That case yields an all-zero vector.
        """
        if descriptors is None:
            values = []
        else:
            values = np.asarray(descriptors).ravel()[:self.VECTOR_DIM].tolist()
        return self._to_fixed_length(values)

    def get_keypoints_descriptors(self, img_folder_path):
        """Encode every .jpg/.jpeg under ``img_folder_path`` and insert it into training_data.

        The class label of each image is the name of its immediate parent
        directory. Each row is committed on its own. Unreadable images are
        reported and skipped.
        """
        print("Loading images from: ", img_folder_path)
        sift = cv2.SIFT_create()  # one detector reused for every file
        for ext in ('jpg', 'jpeg'):
            for file in glob.glob(f"{img_folder_path}/**/*.{ext}", recursive=True):
                try:
                    image = self._load_grayscale(file)
                except ValueError as err:
                    print(f"Skipping {file}: {err}")
                    continue

                keypoints, descriptors = sift.detectAndCompute(image, None)
                kp_vector = self._keypoints_to_vector(keypoints)
                des_vector = self._descriptors_to_vector(descriptors)

                class_type = os.path.basename(os.path.dirname(file))
                file_name = os.path.basename(file)
                try:
                    self.cursor.execute(
                        self._INSERT_TRAINING_SQL,
                        (file, class_type, file_name, kp_vector, des_vector),
                    )
                    self.pgconn.commit()
                except pg2.Error:
                    # Leave the connection usable rather than in an aborted transaction.
                    self.pgconn.rollback()
                    raise

    def calculate_similarity(self, query_des, query_kp, threshold):
        """Find the training row closest to the given query vectors.

        Rows are ordered by descriptor distance, then keypoint distance. Only
        rows whose two distances are both ``>= threshold`` are considered.
        NOTE(review): a distance lower bound is unusual for a "similarity
        threshold"; the only caller passes 0, which admits every row.

        Args:
            query_des: VECTOR_DIM-long descriptor vector (list of numbers).
            query_kp: VECTOR_DIM-long keypoint vector (list of numbers).
            threshold: minimum distance a row must have to be considered.

        Returns:
            ``(combined_similarity, class_type)``. ``combined_similarity`` is
            currently just the descriptor distance. Returns ``(0, "UNKNOWN")``
            when no row qualifies.
        """
        params = {'query_des': query_des, 'query_kp': query_kp, 'threshold': threshold}
        try:
            self.cursor.execute(self._SIMILARITY_SQL, params)
            result = self.cursor.fetchone()
            self.pgconn.commit()
        except pg2.Error:
            self.pgconn.rollback()
            raise

        if result is None:
            return 0, "UNKNOWN"  # default when no match is found
        class_type, desc_distance, _kp_distance = result
        # Combined similarity is just the descriptor distance for now.
        return desc_distance, class_type

    def find_validation_image_class_type(self, query_image_path):
        """Return the class_type of the training image that best matches ``query_image_path``."""
        keypoints, descriptors = self.compute_sift(query_image_path)
        query_kp = self._keypoints_to_vector(keypoints)
        query_des = self._descriptors_to_vector(descriptors)
        # Distances are compared with ``>=``; a threshold of 0 keeps every row.
        _similarity, class_type = self.calculate_similarity(query_des, query_kp, threshold=0)
        return class_type
    
    

# Inputs for this run: the root folder of labelled training images (each
# image's class label is its parent directory name) and one validation image.
train_folder_path = "./img_test/train"
query_image_path = "./img_test/val/beautiful-blessings/beautiful-blessings-check-3.jpeg"

# Constructing the matcher opens the PostgreSQL connection.
image_matcher = ImageMatcher()

# Encode every training image and insert it into the training_data table.
image_matcher.get_keypoints_descriptors(train_folder_path)

# Look up the closest training image and report its class label.
class_type = image_matcher.find_validation_image_class_type(query_image_path)
print("\nClass type of query image:", class_type)

Loading images from:  ./img_test/train

Class type of query image: america-beautiful
