# <b><center>Vector Index Creation</center></b>

  

This notebook demonstrates how to:
1. Connect to Presto/Iceberg
2. Generate vector embeddings from text data
3. Store embeddings in an Iceberg table
4. Create a vector index for fast similarity search

**Prerequisites:**

- A Presto server running locally, on a VM, on CPD, or on SaaS
- An Iceberg catalog must already exist, with a storage bucket associated with it
- The test data must be in CSV format and accessible either from the working directory or via a specified file path
- Install the required Python packages with `pip install -r requirements.txt`

For a full breakdown, please refer to [README_NOTEBOOKS.md](README_NOTEBOOKS.md).


**Expected Output:**
- The generated embeddings are stored in an Iceberg table.
- A vector index is created for fast retrieval.

#### <b>Configuration & Setup</b>

In [None]:

# === CONFIGURATION ===
# Fill in the empty values below before running the rest of the notebook.

# --- Presto connection ---
# Address of the Presto host engine.
HOST = ""
PORT = 8080
# URL scheme used to reach Presto: "http" or "https".
HTTP_SCHEME = "http"

# --- Authentication (optional) ---
# Username on the host instance.
USER = ""
# Leave empty if not using auth.
PASSWORD = ""
# Only for dev with self-signed certs.
DISABLE_SSL_VERIFICATION = True

# --- Target catalog / schema / table ---
# Presto catalog name.
CATALOG = ""
# Presto schema name.
SCHEMA = ""
# Storage location used when the schema is created; replace <bucketname>
# with your bucket.
S3_LOCATION = "s3a://<bucketname>/" + SCHEMA
# Presto table name.
TABLE = ""

# --- Table schema configuration (customizable) ---
# Name of the text column in the table.
TEXT_COLUMN = ""
# Name of the embedding column in the table.
EMBEDDING_COLUMN = ""

# --- Data processing ---
# Path to the input CSV file.
CSV_FILE = "path/to/csv"
# CHANGE THIS to the CSV column you want to embed.
SOURCE_COLUMN = ""

# CSV file the generated embeddings are written to.
OUTPUT_FILE = "embeddings.csv"
# File the generated INSERT statements are written to.
SQL_FILE = "inserts.sql"
# Sentence-transformers model used to compute the embeddings.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Maximum number of rows read from CSV_FILE.
NUM_ROWS = 100
# Number of rows per batched INSERT statement. Keep this small (around 15 or
# fewer) for optimal performance/stability.
BATCH_SIZE = 15

#### Initializing helper functions

In [None]:
# === IMPORTS ===
import prestodb
from prestodb.exceptions import PrestoUserError
from typing import Dict, Any, Tuple, List, Optional, cast
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import os
import csv
import json
import re
import textwrap
import urllib3
# Silence urllib3's InsecureRequestWarning for the whole process; it would
# otherwise be emitted for HTTPS requests made with certificate verification
# turned off.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


# === HELPER FUNCTIONS ===
def get_presto_connection(
    host: str = HOST,
    port: int = PORT,
    user: str = USER,
    catalog: str = CATALOG,
    schema: str = SCHEMA,
    http_scheme: str = HTTP_SCHEME,
    principal_id: str = USER,
    password: str = PASSWORD,
    disable_ssl_verification: bool = DISABLE_SSL_VERIFICATION
) -> prestodb.dbapi.Connection:
    """Create a Presto DB-API connection.

    Basic authentication is attached only when both ``user`` and ``password``
    are non-empty. TLS certificate verification is switched off only when
    ``disable_ssl_verification`` is true and ``http_scheme`` is ``'https'``.

    The defaults are bound to the configuration constants when this function
    is defined, so re-run this definition after changing the configuration.

    Args:
        host: Presto host address.
        port: Presto port.
        user: User name sent with the connection (and used for basic auth).
        catalog: Default catalog for the session.
        schema: Default schema for the session.
        http_scheme: ``'http'`` or ``'https'``.
        principal_id: Currently unused; kept so existing callers keep working.
        password: Password for basic auth; empty disables authentication.
        disable_ssl_verification: Skip certificate checks on HTTPS sessions.

    Returns:
        The prestodb connection object.

    Raises:
        PrestoUserError: If building the connection fails; the original
            exception is chained as ``__cause__``.
    """
    # Build connection parameters
    conn_params = {
        'host': host,
        'port': port,
        'user': user,
        'catalog': catalog,
        'schema': schema,
        'http_scheme': http_scheme,
    }

    try:
        # Add basic authentication if credentials provided
        if user and password:
            conn_params['auth'] = prestodb.auth.BasicAuthentication(user, password)

        conn = prestodb.dbapi.connect(**conn_params)

        # Disable SSL verification if requested (for self-signed certificates)
        if disable_ssl_verification and http_scheme == 'https':
            conn._http_session.verify = False
    except Exception as e:
        # prestodb's query-error classes read their fields (message, type,
        # name) from a dict payload; handing them a bare string makes
        # str(exc) fail when the caller prints the error.
        # NOTE(review): confirm against the installed prestodb version.
        raise PrestoUserError({'message': f"Connection Error: {e}"}) from e

    print(f"Connected to {http_scheme}://{host}:{port}")
    return conn

def execute_query(conn: prestodb.dbapi.Connection, sql: str, fetch: bool = False):
    """Execute one SQL statement on ``conn`` and wait for it to finish.

    Args:
        conn: Open Presto connection.
        sql: Statement to run.
        fetch: When true, return the result rows and cursor description.

    Returns:
        ``(rows, description)`` when ``fetch`` is true, otherwise
        ``(None, None)``.

    Raises:
        prestodb.exceptions.PrestoUserError: Re-raised after being printed.
    """
    cursor = None
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        # Always consume the result, even when the caller does not want the
        # rows: the prestodb cursor only sends the initial request in
        # execute() and advances the query while results are read, so an
        # undrained DDL/DML statement may not run to completion and any
        # server-side failure would go unnoticed.
        # NOTE(review): confirm against the installed prestodb version.
        rows = cursor.fetchall()
        if fetch:
            return rows, cursor.description
        return None, None
    except prestodb.exceptions.PrestoUserError as e:
        print(f"Query Failed: {e}")
        raise
    finally:
        if cursor is not None:
            cursor.close()

# Signal that this cell ran to the end without raising.
print(" Configuration and helper functions loaded")

#### <b>Test Connection</b>

In [None]:
# Smoke test: open a connection, run a trivial query and report the outcome.
# Failures are reported rather than raised so the notebook keeps running.
conn = None
try:
    conn = get_presto_connection()
    results, _ = execute_query(conn, "SELECT 1", fetch=True)

    if not results:
        print(" Connection test returned unexpected result")
    else:
        print("Connection test PASSED")
        print(f"Connected to catalog: {CATALOG}, schema: {SCHEMA}")

except Exception as e:
    print(f"Connection test FAILED: {e}")
finally:
    # Release the connection whether or not the test succeeded.
    if conn:
        conn.close()

#### <b>Generate Embeddings</b>

In [None]:
# Load and preprocess data
if not os.path.exists(CSV_FILE):
    raise FileNotFoundError(f"Error: '{CSV_FILE}' not found")

# Only the first NUM_ROWS rows of the CSV are read.
df: pd.DataFrame = pd.read_csv(CSV_FILE, nrows=NUM_ROWS)

# Fail early with an actionable message instead of a bare pandas KeyError.
if SOURCE_COLUMN not in df.columns:
    raise KeyError(
        f"Column '{SOURCE_COLUMN}' not found in '{CSV_FILE}'. "
        f"Available columns: {list(df.columns)}"
    )

# Use single column from CSV, dropping rows where it is missing
df = cast(pd.DataFrame, df[[SOURCE_COLUMN]].dropna().reset_index(drop=True))

# Stop before loading the model when there is nothing to embed; otherwise
# the embeddings[0] lookup below fails with an opaque IndexError.
if df.empty:
    raise ValueError(
        f"No non-empty values found in column '{SOURCE_COLUMN}' of '{CSV_FILE}'"
    )

df[TEXT_COLUMN] = df[SOURCE_COLUMN].astype(str)
df['row_id'] = df.index + 1  # 1-based ids, assigned after the index reset
df = cast(pd.DataFrame, df[['row_id', TEXT_COLUMN]])


# Load the embedding model
model = SentenceTransformer(MODEL_NAME)
print(f"✓ Model '{MODEL_NAME}' loaded")
print("  Generating embeddings...")

# Generate embeddings (one vector per row of text)
text_list = cast(List[str], df[TEXT_COLUMN].tolist())
embeddings = model.encode(
    text_list,
    show_progress_bar=True,
    normalize_embeddings=True
)

embedding_dim = len(embeddings[0])
print(f"✓ Generated {len(embeddings)} embeddings (dimension: {embedding_dim})")

# Save to CSV. Each embedding is stored as a plain Python list, which
# to_csv writes in its string form (e.g. "[0.1, 0.2]").
df[EMBEDDING_COLUMN] = [emb.tolist() for emb in embeddings]
df.to_csv(OUTPUT_FILE, index=False)
print(f"✓ Saved embeddings to '{OUTPUT_FILE}'")

# Display sample
print("\nSample data (first 3 rows):")
print(cast(pd.DataFrame, df[['row_id', TEXT_COLUMN]].head(3)).to_string(index=False))


#### <b>Create Schema</b>

In [None]:
# Create the target schema (if missing) at the configured storage location.
create_schema_sql = f"""
            CREATE SCHEMA IF NOT EXISTS {CATALOG}.{SCHEMA}
            WITH (
                location = '{S3_LOCATION}'
            )
        """

conn = None
try:
    conn = get_presto_connection()
    execute_query(conn, create_schema_sql, fetch=False)
finally:
    # Always release the connection, as the other cells do. Errors still
    # propagate, so the success message below is only printed when the
    # statement did not raise.
    if conn:
        conn.close()

print(f" SUCCESS: Schema '{SCHEMA}' created or already exists in '{CATALOG}'.")
print(f" Location: {S3_LOCATION}")

#### <b>Create Iceberg Table</b>

In [None]:
# Fully qualified name of the table that will hold the embeddings.
FULL_TABLE_NAME = ".".join((CATALOG, SCHEMA, TABLE))

# Three columns: a numeric row id, the source text, and its embedding vector.
CREATE_TABLE_DDL = f"""
CREATE TABLE IF NOT EXISTS {FULL_TABLE_NAME} (
    row_id      BIGINT,
    {TEXT_COLUMN}     VARCHAR,
    {EMBEDDING_COLUMN}   ARRAY(REAL)
)
"""

conn = None
try:
    conn = get_presto_connection()

    # Issue the DDL.
    print(f"Creating table: {FULL_TABLE_NAME}")
    execute_query(conn, CREATE_TABLE_DDL, fetch=False)
    print("Table created successfully")

    # Read the schema back with DESCRIBE and print it.
    describe_sql = f"DESCRIBE {FULL_TABLE_NAME}"
    describe_results, description = execute_query(conn, describe_sql, fetch=True)

    if describe_results:
        separator = "-" * 50
        print("\nTable Schema:")
        print(separator)
        for row in describe_results:
            column_name, column_type = row[0], row[1]
            print(f"  {column_name:<15} {column_type}")
        print(separator)
        print(f"Table validated with {len(describe_results)} columns")

except Exception as e:
    # Report the failure instead of raising so the notebook keeps running.
    print(f"Table creation failed: {e}")
finally:
    if conn:
        conn.close()

#### <b>Generate & Execute Inserts</b>

In [None]:
# Generate SQL inserts
print(f"Generating batch INSERT statements (batch size: {BATCH_SIZE})...")


def _sql_string_literal(text: str) -> str:
    """Return ``text`` quoted as a SQL string literal.

    Only single quotes are escaped (by doubling them). Presto string
    literals have no backslash escapes, so a backslash is stored as-is;
    doubling it would silently change the inserted text.
    """
    return "'" + text.replace("'", "''") + "'"


# Shared head of every batched INSERT statement.
insert_prefix = (
    f"INSERT INTO {CATALOG}.{SCHEMA}.{TABLE} (row_id, {TEXT_COLUMN}, {EMBEDDING_COLUMN}) "
    "VALUES "
)
inserts = []
batch = []

# 'utf-8-sig' strips a leading BOM if one is present and otherwise behaves
# like plain UTF-8, so the header names never carry a stray '\ufeff'.
with open(OUTPUT_FILE, newline='', encoding='utf-8-sig') as f:
    for row in csv.DictReader(f):
        # int()/float() ensure only numeric text is interpolated into the SQL.
        row_id = int(row['row_id'])
        text_literal = _sql_string_literal(row[TEXT_COLUMN])
        embedding_list = json.loads(row[EMBEDDING_COLUMN])
        embedding_array = "ARRAY[" + ",".join(str(float(v)) for v in embedding_list) + "]"

        batch.append(f"({row_id}, {text_literal}, CAST({embedding_array} AS ARRAY(REAL)))")

        # Flush a full batch into one multi-row INSERT statement.
        if len(batch) >= BATCH_SIZE:
            inserts.append(insert_prefix + ", ".join(batch))
            batch = []

# Flush the final, partially filled batch.
if batch:
    inserts.append(insert_prefix + ", ".join(batch))

# Save SQL file
with open(SQL_FILE, 'w', encoding='utf-8') as f:
    f.write("\n".join(inserts))

print(f"Generated {len(inserts)} INSERT statements")
print(f"Saved to '{SQL_FILE}'")

# Execute inserts
print("\nExecuting INSERT statements...")
conn = None
cursor = None
try:
    conn = get_presto_connection()
    cursor = conn.cursor()

    for i, statement in enumerate(inserts, 1):
        print(f"  Executing batch {i}/{len(inserts)}...", end='\r')
        cursor.execute(statement)
        # Consume the result so the INSERT runs to completion and any
        # server-side error is raised here; the prestodb cursor only sends
        # the initial request in execute().
        # NOTE(review): confirm against the installed prestodb version.
        cursor.fetchall()

    conn.commit()
    print(f"\nAll {len(inserts)} batches executed successfully")

    # Verify row count
    count_results, _ = execute_query(
        conn, f"SELECT COUNT(*) FROM {CATALOG}.{SCHEMA}.{TABLE}", fetch=True
    )
    if count_results:
        print(f"Table now contains {count_results[0][0]} rows")

except Exception as e:
    print(f"\nInsert execution failed: {e}")
    if conn:
        # rollback() may itself raise (for example when no transaction is
        # open); do not let that replace the report above.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"Rollback not performed: {rollback_error}")
finally:
    if cursor is not None:
        cursor.close()
    if conn:
        conn.close()

#### <b>Create Vector Index</b>

In [None]:
# Procedure call that builds the vector index on the embedding column.
INDEX_COMMAND = f"CALL {CATALOG}.system.CREATE_VEC_INDEX('{CATALOG}.{SCHEMA}.{TABLE}.{EMBEDDING_COLUMN}')"

conn = None
# Initialise before the try block so the finally clause never touches an
# undefined name or a stale cursor left over from an earlier cell.
cursor = None
try:
    conn = get_presto_connection()
    cursor = conn.cursor()

    # Report the configured column rather than a hard-coded name.
    print(f"Creating vector index on {TABLE}.{EMBEDDING_COLUMN}...")
    cursor.execute(INDEX_COMMAND)
    # Consume the result so the call runs to completion and any server-side
    # error is raised here rather than being lost.
    # NOTE(review): confirm against the installed prestodb version.
    cursor.fetchall()
    conn.commit()

    print(" Vector index created successfully")
    print(" Table is now ready for similarity search")

except Exception as e:
    # Report the failure instead of raising so the notebook keeps running.
    print(f" Index creation failed: {e}")
finally:
    if cursor is not None:
        cursor.close()
    if conn:
        conn.close()