# IGDB Embeddings pgvector Demo - Get Embeddings from Model Endpoint & Import Embeddings into Database

In [1]:
!pip install tqdm "psycopg[binary]" pgvector --quiet

In [2]:
import json
from multiprocessing import cpu_count

import boto3
import pandas as pd
import psycopg
import sagemaker
from pgvector.psycopg import register_vector
from sagemaker.huggingface.model import HuggingFacePredictor
from tqdm.contrib.concurrent import process_map, thread_map

In [3]:
# One SageMaker session for the whole notebook; its `sagemaker_client` is reused
# below to look up this Notebook Instance's tags.
sess = sagemaker.Session()

# Low-level boto3 clients: S3 for the source dataset, Secrets Manager for DB credentials
s3 = boto3.client(service_name='s3')
secretsmanager = boto3.client(service_name='secretsmanager')

## Retrieve all variables from the Notebook Instance's tags

In [4]:
# Read this Notebook Instance's ARN straight from the instance metadata file
# (pure Python, so the cell no longer depends on the `jq` binary being installed)
with open('/opt/ml/metadata/resource-metadata.json') as metadata_file:
    NOTEBOOK_ARN = json.load(metadata_file)['ResourceArn']

# Collect every tag on this Notebook Instance; ListTags is paginated, so follow NextToken
tags = []
list_tags_kwargs = {'ResourceArn': NOTEBOOK_ARN}
while True:
    list_tags_response = sess.sagemaker_client.list_tags(**list_tags_kwargs)
    tags.extend(list_tags_response['Tags'])
    next_token = list_tags_response.get('NextToken')
    if not next_token:
        break
    list_tags_kwargs['NextToken'] = next_token

# Index tags by key so a missing tag fails with a clear message instead of an opaque IndexError
tag_values = {tag['Key']: tag['Value'] for tag in tags}


def _required_tag(key):
    """Return the value of tag `key`; raise KeyError naming the tag if it is absent."""
    try:
        return tag_values[key]
    except KeyError:
        raise KeyError(f"Notebook instance {NOTEBOOK_ARN} is missing required tag {key!r}") from None


ASSETS_BUCKET = _required_tag('VAR_ASSETS_BUCKET')
DB_SECRET_ARN = _required_tag('VAR_DB_SECRET_ARN')
MODEL_ENDPOINT = _required_tag('VAR_MODEL_ENDPOINT')

## Get Source Dataset from Bucket

In [5]:
# Fetch the raw dataset CSV from the assets bucket into the working directory
s3.download_file(
    Bucket=ASSETS_BUCKET,
    Key='nintendo_switch_games.csv',
    Filename='./nintendo_switch_games.csv',
)

In [6]:
# Load the source dataset and preview it. The CSV carries a leftover pandas index
# column (it appears as "Unnamed: 0"); drop it if present so only real fields remain.
games_df = pd.read_csv("./nintendo_switch_games.csv").drop(columns=["Unnamed: 0"], errors="ignore")
games_df.head()

Unnamed: 0,igdb_id,name,summary,description,url,artwork_hash,screenshot_hash
0,174898,Clash of Chess,Our app is ideal for everyone. It contains 10 ...,"Title: ""Clash of Chess"" Summary: Our app is id...",https://www.igdb.com/games/clash-of-chess,ar15x4,sce3dq
1,186554,Minepull,More than just a puzzle game. Minepull highlig...,"Title: ""Minepull"" Summary: More than just a pu...",https://www.igdb.com/games/minepull,ar1c3t,scf5ib
2,186935,Biker Garage: Mechanic Simulator,Biker Garage: Mechanic Simulator allows you to...,"Title: ""Biker Garage: Mechanic Simulator"" Summ...",https://www.igdb.com/games/biker-garage-mechan...,ar1ca7,scf75j
3,187097,Geography Quiz Festival: Guess the Countries,Advance through our beloved game by completing...,"Title: ""Geography Quiz Festival: Guess the Cou...",https://www.igdb.com/games/geography-quiz-fest...,ar1ci3,scf812
4,187475,Chess: Clash of Kings,Our app is ideal for everyone. It contains 10 ...,"Title: ""Chess: Clash of Kings"" Summary: Our ap...",https://www.igdb.com/games/chess-clash-of-kings,ar1d21,scf9ms


## Get Embeddings from Model Inference Endpoint

In [7]:
# Attach to the model endpoint named in the notebook tags, reusing this notebook's session
predictor = HuggingFacePredictor(endpoint_name=MODEL_ENDPOINT, sagemaker_session=sess)

In [8]:
# Inference
def generate_embeddings(description):
    """Return the embedding the model endpoint produces for one piece of text.

    Parameters
    ----------
    description : str
        Text sent to the endpoint as ``{"inputs": description}``.

    Returns
    -------
    The ``'vectors'`` entry of the endpoint's response (rendered in the output
    below as a flat list of floats).
    """
    prediction = predictor.predict(data={"inputs": description})
    return prediction['vectors']


# Endpoint calls are I/O-bound, so fan them out over threads. A process pool would fork
# the boto3 client held by `predictor` into every worker; boto3 documents that clients
# must not be shared across processes (responses can come back in the wrong order, i.e.
# embeddings could be attached to the wrong game). boto3 clients are thread-safe.
workers = 1 * cpu_count()
chunksize = 32

# Make inferences (thread_map returns results in input order)
descriptions = games_df['description'].tolist()
vectors = thread_map(generate_embeddings, descriptions, max_workers=workers, chunksize=chunksize)
assert len(vectors) == len(games_df), f"expected {len(games_df)} embeddings, got {len(vectors)}"

# Write embeddings into the DataFrame positionally; a Series keeps each list as one cell value
games_df["description_embeddings"] = pd.Series(vectors, index=games_df.index)
games_df.head()

  0%|          | 0/1468 [00:00<?, ?it/s]

Unnamed: 0,igdb_id,name,summary,description,url,artwork_hash,screenshot_hash,description_embeddings
0,174898,Clash of Chess,Our app is ideal for everyone. It contains 10 ...,"Title: ""Clash of Chess"" Summary: Our app is id...",https://www.igdb.com/games/clash-of-chess,ar15x4,sce3dq,"[-0.005498327314853668, -0.03673512861132622, ..."
1,186554,Minepull,More than just a puzzle game. Minepull highlig...,"Title: ""Minepull"" Summary: More than just a pu...",https://www.igdb.com/games/minepull,ar1c3t,scf5ib,"[0.02409840002655983, -0.025014473125338554, 0..."
2,186935,Biker Garage: Mechanic Simulator,Biker Garage: Mechanic Simulator allows you to...,"Title: ""Biker Garage: Mechanic Simulator"" Summ...",https://www.igdb.com/games/biker-garage-mechan...,ar1ca7,scf75j,"[-0.06617075949907303, 0.032370537519454956, -..."
3,187097,Geography Quiz Festival: Guess the Countries,Advance through our beloved game by completing...,"Title: ""Geography Quiz Festival: Guess the Cou...",https://www.igdb.com/games/geography-quiz-fest...,ar1ci3,scf812,"[0.1011415347456932, 0.0005400646477937698, 0...."
4,187475,Chess: Clash of Kings,Our app is ideal for everyone. It contains 10 ...,"Title: ""Chess: Clash of Kings"" Summary: Our ap...",https://www.igdb.com/games/chess-clash-of-kings,ar1d21,scf9ms,"[-0.025009868666529655, -0.01236047875136137, ..."


## Import Embeddings into PostgreSQL Database

In [11]:
# Get database credentials from Secrets Manager; the secret's JSON string must carry
# the host/port/username/password keys read below.
db_secret = secretsmanager.get_secret_value(SecretId=DB_SECRET_ARN)
db_secret_string = json.loads(db_secret['SecretString'])

# Unpack connection parameters into variables only; the parsed secret includes the
# password, so it is intentionally never displayed as cell output.
db_host = db_secret_string['host']
db_port = db_secret_string['port']
db_user = db_secret_string['username']
db_pass = db_secret_string['password']

In [12]:
# Embedding width produced by the model endpoint; must match the vector(...) column type
EMBEDDING_DIM = 384
# Scalar text columns copied from games_df into the igdb table (in insert order)
TEXT_COLUMNS = ["name", "summary", "description", "url", "artwork_hash", "screenshot_hash"]

# Build all insert parameters up front. Missing values (NaN) become None -> SQL NULL;
# otherwise they would be sent as floats and the text columns would store the string 'NaN'.
igdb_rows = [
    (
        int(record["igdb_id"]),
        *(None if pd.isna(record[column]) else record[column] for column in TEXT_COLUMNS),
        list(record["description_embeddings"]),
    )
    for record in games_df.to_dict("records")
]

# Fail fast, before touching the database, if any embedding has the wrong width
wrong_dim_ids = [row[0] for row in igdb_rows if len(row[-1]) != EMBEDDING_DIM]
if wrong_dim_ids:
    raise ValueError(
        f"{len(wrong_dim_ids)} embeddings are not {EMBEDDING_DIM}-dimensional "
        f"(first igdb_ids: {wrong_dim_ids[:5]})"
    )

# IVFFlat sizing per pgvector guidance: start at rows / 1000 lists (up to 1M rows).
# With only ~1.5k rows, 100 lists leaves ~15 vectors per list, so the default probes=1
# scans ~1% of the table and queries can return fewer rows than their LIMIT.
ivfflat_lists = max(1, len(igdb_rows) // 1000)

# Connect to Database
with psycopg.connect(host=db_host, user=db_user, password=db_pass, port=db_port, connect_timeout=10, autocommit=True) as conn:
    with conn.cursor() as cur:
        # Enable pgvector extension, then register its type adapters on this connection
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
        register_vector(conn)

        # Rebuild the table in one transaction: if anything fails, the previous igdb
        # table is left intact instead of a partially loaded one.
        with conn.transaction():
            cur.execute("DROP TABLE IF EXISTS igdb")
            cur.execute(f"""CREATE TABLE igdb(
                           igdb_id bigserial primary key,
                           name text,
                           summary text,
                           description text,
                           url text,
                           artwork_hash text,
                           screenshot_hash text,
                           description_embeddings vector({EMBEDDING_DIM}));""")

            # Insert data into IGDB table as one batched call instead of one round trip per row
            cur.executemany("""INSERT INTO igdb
                                  (igdb_id, name, summary, description, url, artwork_hash, screenshot_hash, description_embeddings)
                              VALUES(%s, %s, %s, %s, %s, %s, %s, %s);""",
                            igdb_rows)

            # Cosine distance index, built after loading so the IVF lists are trained on
            # the data (use vector_l2_ops instead for L2 distance)
            cur.execute(f"""CREATE INDEX igdb_description_embeddings_idx ON igdb
                   USING ivfflat (description_embeddings vector_cosine_ops) WITH (lists = {int(ivfflat_lists)});""")

        # VACUUM cannot run inside a transaction block; autocommit=True lets it run here
        cur.execute("VACUUM ANALYZE igdb;")