In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models
from torchvision.models.feature_extraction import create_feature_extractor
from PIL import Image
import psycopg
from pgvector.psycopg import register_vector
import os
import numpy as np

In [None]:
# Run on Apple's MPS backend when torch reports it as available, otherwise on the CPU.
DEVICE = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
# Relative path of a saved model checkpoint file.
CHECKPOINT_PATH = (
    '../checkpoints/removed-items-classification-vit/'
    'best_model_2024-05-09_03-45-39_epoch_5.pt'
)
# NOTE(review): presumably the side length (pixels) of the model input — not used in the cells shown here.
INPUT_SIZE = 224

In [2]:
class ItemsDatabaseModule:
    """Small convenience wrapper around one psycopg connection with pgvector enabled.

    The instance owns a connection (``self.conn``), a single cursor
    (``self.cursor``) and remembers the table most recently passed to
    ``create_table`` (``self.table_name``), which ``insert_value`` and
    ``create_index`` operate on.

    It can be used as a context manager so the connection is always closed::

        with ItemsDatabaseModule() as db:
            rows = db.query("SELECT image FROM items WHERE image = %s", (name,))

    Table and operator-class names are interpolated into SQL text verbatim
    (identifiers cannot be sent as bind parameters), so only pass trusted
    names to ``create_table``, ``drop_table`` and ``create_index``.
    """

    def __init__(self, dbname="postgres", user="mac", host="localhost", password=None, port=5432):
        """Connect, make sure the ``vector`` extension exists and open a cursor.

        Args:
            dbname: Database to connect to.
            user: Role used for the connection.
            host: Server host name.
            password: Password for ``user``. When ``None`` (the default) the
                ``PGPASSWORD`` environment variable is used instead, so that
                no credential has to be stored in the notebook.
            port: Server port.
        """
        if password is None:
            password = os.environ.get("PGPASSWORD")
        self.conn = psycopg.connect(dbname=dbname, user=user, host=host, password=password, port=port)
        self.conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
        # Finish the transaction opened by the statement above instead of
        # leaving the session idle inside it.
        self.conn.commit()
        register_vector(self.conn)
        self.cursor = self.conn.cursor()
        self.table_name = None

    def __enter__(self):
        """Return the instance itself so it can be bound in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close cursor and connection on leaving the ``with`` block; never suppresses exceptions."""
        self.close()
        return False

    def _require_table(self):
        """Return the active table name, or raise if ``create_table`` has not set one."""
        if self.table_name is None:
            raise RuntimeError("No active table: call create_table() before this operation")
        return self.table_name

    def is_connected(self):
        """Run ``SELECT 1`` as a liveness probe.

        Returns:
            ``True`` when the statement succeeds; ``False`` (after printing the
            error) when it raises.
        """
        try:
            self.cursor.execute('SELECT 1')
            print("Database connection established")
            return True
        # Deliberately broad: any failure of the probe means "not connected".
        except Exception as error:
            print(error)
            return False

    def create_index(self, index_name='vector_l2_ops'):
        """Create an HNSW index on the ``embedding`` column of the active table and commit.

        Args:
            index_name: Despite its name, this value is placed in the
                operator-class position of the statement
                (``embedding <index_name>``).

        Raises:
            RuntimeError: If no table has been selected via ``create_table``.
        """
        table = self._require_table()
        self.cursor.execute(f'CREATE INDEX ON {table} using hnsw (embedding {index_name})')
        self.conn.commit()

    def query(self, query, params=None):
        """Execute ``query`` and return every row.

        Args:
            query: SQL text, optionally containing ``%s`` placeholders.
            params: Optional values for the placeholders. Prefer this over
                formatting values into ``query``.

        Returns:
            The result of ``cursor.fetchall()``.
        """
        self.cursor.execute(query, params)
        return self.cursor.fetchall()

    def close(self):
        """Close the cursor, then the connection."""
        self.cursor.close()
        self.conn.close()

    def create_table(self, table_name, embedding_size=256):
        """Create the items table if it is missing, commit, and make it the active table.

        Columns: ``id`` (serial primary key), ``image`` (text), ``category``
        (text) and ``embedding`` (``vector(embedding_size)``).

        Args:
            table_name: Name of the table; remembered in ``self.table_name``.
            embedding_size: Dimension of the ``embedding`` vector column.
        """
        self.table_name = table_name
        query = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id SERIAL PRIMARY KEY, 
                image TEXT, 
                category TEXT,
                embedding vector({int(embedding_size)})
            );
        """
        self.cursor.execute(query)
        self.conn.commit()

    def get_table_name(self):
        """Return the active table name (``None`` until ``create_table`` is called)."""
        return self.table_name

    def insert_value(self, image, category, embedding):
        """Insert one row into the active table and commit.

        Args:
            image: Value for the ``image`` column.
            category: Value for the ``category`` column.
            embedding: Value for the ``embedding`` column.

        Raises:
            RuntimeError: If no table has been selected via ``create_table``.
        """
        table = self._require_table()
        query = f"""
            INSERT INTO {table} (image, category, embedding)
            VALUES (%s, %s, %s);
        """
        self.cursor.execute(query, (image, category, embedding))
        self.conn.commit()

    def drop_table(self, table_name):
        """Drop ``table_name`` if it exists and commit.

        The active table name stored on the instance is not modified.
        """
        query = f"""
            DROP TABLE IF EXISTS {table_name};
        """
        self.cursor.execute(query)
        self.conn.commit()

In [7]:
# Connect using the module's default settings and fetch a handful of stored vectors.
itemsdb = ItemsDatabaseModule()
sample_embeddings_sql = """
    SELECT embedding
    FROM items
    LIMIT 5
"""
embed = itemsdb.query(sample_embeddings_sql)
# Display the embedding held in each of the first two rows.
for row_idx in (0, 1):
    print(embed[row_idx][0])

[ 1.1754481   0.58414304 -0.8735852  -0.80906224  0.21097395 -1.1972939
  0.24609399 -0.43785214  1.0880796  -1.1027856  -0.14107388 -0.2525432
 -1.2911388   1.2611641  -0.4131418   0.0339569  -0.6074739   1.6441836
 -0.7509738  -0.93411064 -1.4016767  -0.52660024  0.07844225  0.71260494
  0.6081758   0.7527648  -0.55179316 -0.34099808  0.30202764  1.0226698
  0.00483921  0.3631719   0.26087612 -0.69598454 -0.17281029  0.5572995
 -0.9449818   0.12198666 -0.3130826  -1.0290401  -0.9818472  -0.14799157
  0.21498512 -0.7195488   0.35893244 -0.3050649   0.55742073 -1.1754392
  0.7344784  -0.2677799  -0.26615003  1.7122817   0.4783408  -0.22513391
  0.15803085 -0.6139762   0.8865382   0.7504543   1.3191124  -0.37802595
  0.6462114   0.68237543 -1.2145308  -1.8563029  -1.5383689   0.833838
 -1.1396888   0.7771813  -1.2953123  -1.7873651   1.0049953   1.3489013
  1.2226381   1.1603274   1.9299953  -0.36325225 -2.100444   -0.18625826
 -0.93621284 -1.8880004   1.0206652  -0.81788886  1.6037066 

In [None]:
# Upper bound on the number of combos processed; set to None to use every image in the table.
COMBO_SAMPLE_LIMIT = 5

# Every consumer of `distinct_combos` uses row[0] as the value compared against the
# `image` column, so the rows must hold image keys. Selecting `embedding` here (as
# this cell used to) makes those lookups match nothing.
# ORDER BY keeps the selection stable between runs.
_limit_clause = "" if COMBO_SAMPLE_LIMIT is None else f"LIMIT {int(COMBO_SAMPLE_LIMIT)}"
distinct_combos = itemsdb.query(f"""
    SELECT DISTINCT image
    FROM items
    ORDER BY image
    {_limit_clause}
""")
print(type(distinct_combos))
# Preview only: the full list can be long when the limit is disabled.
print(distinct_combos[:5])
if len(distinct_combos) > 1:
    print(distinct_combos[1][0])

In [None]:
# Fetch every embedding stored for the first combo.
# The key is sent as a bind parameter rather than formatted into the SQL text, so
# quotes or other special characters in it cannot break the statement.
# str() mirrors the text interpolation this replaces; `image` is a TEXT column.
first_combo_key = str(distinct_combos[0][0])
itemsdb.cursor.execute(
    """
    SELECT embedding
    FROM items
    WHERE image = %s
    """,
    (first_combo_key,),
)
items_embeddings = itemsdb.cursor.fetchall()

if items_embeddings:
    print(items_embeddings[0][0].shape)
else:
    # Report the empty result instead of failing with an IndexError.
    print(f"No embeddings stored for image {first_combo_key!r}")

In [None]:
# Layout of one matrix row: up to MAX_ITEMS_PER_COMBO item embeddings of
# EMBEDDING_DIM values each, written back to back (5 * 512 = 2560 columns).
EMBEDDING_DIM = 512
MAX_ITEMS_PER_COMBO = 5
COMBO_VECTOR_DIM = EMBEDDING_DIM * MAX_ITEMS_PER_COMBO

# Zero-initialised so that slots without an item stay zero-padded.
embeddings_matrix = np.zeros(shape=(len(distinct_combos), COMBO_VECTOR_DIM))

for idx, combo in enumerate(distinct_combos):
    # str() mirrors the text interpolation this replaces; `image` is a TEXT column.
    image_key = str(combo[0])
    # Bind parameter instead of string formatting: safe for keys containing quotes.
    itemsdb.cursor.execute(
        """
        SELECT embedding
        FROM items
        WHERE image = %s
        """,
        (image_key,),
    )
    items_embeddings = itemsdb.cursor.fetchall()

    sub_idx = 0
    for row in items_embeddings:
        if sub_idx >= COMBO_VECTOR_DIM:
            # Previously hidden by a blanket `except`: items beyond the row capacity are dropped.
            print(f"Combo {image_key}: more than {MAX_ITEMS_PER_COMBO} items, extra items ignored")
            break
        item_embedding = np.asarray(row[0])
        if item_embedding.shape != (EMBEDDING_DIM,):
            # Skip (and report) anything that does not fit a slot; the slot is reused by the next item.
            print(f"Combo {image_key}: skipped embedding with shape {item_embedding.shape}")
            continue
        embeddings_matrix[idx, sub_idx: sub_idx + EMBEDDING_DIM] = item_embedding
        sub_idx += EMBEDDING_DIM
    print(f"Combo {combo[0]} saved!")

print(embeddings_matrix.shape)

In [None]:
# Write the positive matrix to disk as a plain .npy array (pickling disabled).
np.save(file='positive_embedding.npy', arr=embeddings_matrix, allow_pickle=False)

In [None]:
# Layout of one matrix row: up to MAX_ITEMS_PER_COMBO item embeddings of
# EMBEDDING_DIM values each, written back to back (5 * 512 = 2560 columns).
EMBEDDING_DIM = 512
MAX_ITEMS_PER_COMBO = 5
COMBO_VECTOR_DIM = EMBEDDING_DIM * MAX_ITEMS_PER_COMBO
# The last slot that genuine items may start in; one slot is kept free for the foreign item.
GENUINE_CAPACITY = COMBO_VECTOR_DIM - EMBEDDING_DIM

# Zero-initialised so that slots without an item stay zero-padded.
negative_embeddings_matrix = np.zeros(shape=(len(distinct_combos), COMBO_VECTOR_DIM))

for idx, combo in enumerate(distinct_combos):
    # str() mirrors the text interpolation this replaces; `image` is a TEXT column.
    image_key = str(combo[0])
    # NOTE(review): `combo` is a result row, so len(combo) is its column count —
    # 1 whenever the combos query selects a single column. That keeps exactly one
    # genuine item per negative sample; confirm this is the intended number.
    itemsdb.cursor.execute(
        """
        SELECT embedding
        FROM items
        WHERE image = %s
        LIMIT %s
        """,
        (image_key, len(combo)),
    )
    items_embeddings = itemsdb.cursor.fetchall()

    sub_idx = 0
    for row in items_embeddings:
        if sub_idx >= GENUINE_CAPACITY:
            print(f"Combo {image_key}: no free slot left for further genuine items, extra items ignored")
            break
        item_embedding = np.asarray(row[0])
        if item_embedding.shape != (EMBEDDING_DIM,):
            # Skip (and report) anything that does not fit a slot; the slot is reused by the next item.
            print(f"Combo {image_key}: skipped embedding with shape {item_embedding.shape}")
            continue
        negative_embeddings_matrix[idx, sub_idx: sub_idx + EMBEDDING_DIM] = item_embedding
        sub_idx += EMBEDDING_DIM

    # Append one randomly chosen item that belongs to a different image.
    itemsdb.cursor.execute(
        """
        SELECT embedding
        FROM items
        WHERE image != %s
        ORDER BY RANDOM()
        LIMIT 1
        """,
        (image_key,),
    )
    foreign_rows = itemsdb.cursor.fetchall()
    if not foreign_rows:
        # Without a foreign item the row would not be a negative example at all.
        raise RuntimeError(f"No item from another image available to pair with {image_key!r}")
    negative_embeddings_matrix[idx, sub_idx: sub_idx + EMBEDDING_DIM] = np.asarray(foreign_rows[0][0])
    print(f"Combo {combo[0]} saved!")

print(negative_embeddings_matrix.shape)

In [None]:
# Write the negative matrix to disk as a plain .npy array (pickling disabled).
np.save(file='negative_embedding.npy', arr=negative_embeddings_matrix, allow_pickle=False)