In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# Retrieve the active Snowpark session; it is used further down to read the
# source table (session.table) and to run SQL statements (session.sql).
from snowflake.snowpark.context import get_active_session
session = get_active_session()


In [None]:
# Install the embedding library used by the cells below.
# NOTE(review): the version is not pinned — consider pinning it so reruns are reproducible.
! pip install sentence-transformers

**import libraries**

In [None]:
from sentence_transformers import SentenceTransformer
from typing import Any, Dict, List, Optional
import pandas as pd
import json 
import numpy as np 
import re
import sys
from tqdm import tqdm
import torch

**extract call parameters**

In [None]:
# Positional call parameters, in this order:
#   0: source table, 1: output table, 2: id column,
#   3: comma-separated list of columns to embed, 4: embedding model name.
# NOTE(review): indexing starts at 0, which assumes the notebook runtime puts only
# the call arguments in sys.argv (no script name in slot 0) — confirm.
params = sys.argv

# Print first so the raw arguments are visible even when validation fails.
print("sys.argv:", sys.argv)

EXPECTED_PARAM_COUNT = 5
if len(params) < EXPECTED_PARAM_COUNT:
    raise ValueError(
        f"Expected {EXPECTED_PARAM_COUNT} parameters "
        "(source_table, output_table, id_column, columns_to_clean, embedding_model), "
        f"got {len(params)}: {params}"
    )

source_table = params[0]
output_table = params[1]
id_column = params[2]
# Accept "A, B" as well as "A,B": strip blanks around each name and drop empty
# entries, otherwise the later column selection fails on names like " B".
columns_to_clean = [name.strip() for name in params[3].split(',') if name.strip()]
embedding_model = params[4]



**import the data**

In [None]:
  # Example table
# Read the whole source table into pandas, then keep an independent copy of just
# the id column and the columns that will be embedded.
initial_df = session.table(source_table).to_pandas()

columns = [id_column, *columns_to_clean]
df_to_embed = initial_df.loc[:, columns].copy()

**Clean the columns to embed**

In [None]:
#define function to clean text before embedding

def clean_columns_to_embed(text_to_clean: str, start_text: str) -> str:
    """
    Normalise one cell value for embedding and prepend a label to it.

    The value is converted to a string and then, in order: square brackets and
    quotes are removed, "|" becomes ",", the text is lowercased, every character
    outside ``a-z0-9 .,?!`` is dropped, runs of spaces are collapsed, a space
    directly before a comma is removed, and the result is stripped.

    Args:
        text_to_clean (str): The value to clean. Non-string values are
            converted with ``str()``.
        start_text (str): The prefix to put in front of the cleaned text.

    Returns:
        str: "{start_text}: {cleaned_text}." or an empty string when the input
             is missing (None, NaN, NA), empty, or has nothing left after
             cleaning.
    """

    # pandas represents missing cells as NaN/NA/NaT as well as None. Without
    # this check str(nan) would be embedded as the literal word "nan".
    # is_scalar guards pd.isna, which returns an array for list-like inputs.
    if pd.api.types.is_scalar(text_to_clean) and pd.isna(text_to_clean):
        return ""

    text = str(text_to_clean)

    # Remove list/quote punctuation left over from stringified arrays.
    text = re.sub(r"[\[\]'\"]", "", text)

    text = text.replace("|", ",")

    text = text.lower()

    # Keep only lowercase letters, digits, spaces and basic punctuation.
    text = re.sub(r"[^a-z0-9 .,?!]+", "", text)

    text = re.sub(r" +", " ", text)

    text = re.sub(r" ,", ",", text)

    text = text.strip()

    # Covers "" as well as inputs such as "[]" or "   " that clean down to
    # nothing; returning a bare "prefix: ." would only add noise to the text.
    if not text:
        return ""

    return f"{start_text}: {text}."

In [None]:
# Clean every column to embed, then concatenate them into TEXT_TO_EMBED.

cleaned_df = df_to_embed.copy()

for column_name in columns_to_clean:
    # Each value is labelled with its lowercased column name.
    prefix = f" recipe {column_name.lower()}"
    cleaned_df[column_name] = df_to_embed[column_name].apply(
        lambda value, prefix=prefix: clean_columns_to_embed(value, prefix)
    )

# One space-separated string per row, built from the cleaned columns in order.
cleaned_df["TEXT_TO_EMBED"] = cleaned_df[columns_to_clean].agg(" ".join, axis=1)

**Embed the `TEXT_TO_EMBED` column**

In [None]:
# Load the embedding model.
# Use the GPU when torch can see one; otherwise fall back to CPU so the notebook
# still runs (more slowly) on hosts without CUDA instead of failing here.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading embedding model on device: {device}")

model = SentenceTransformer(
    embedding_model,
    # NOTE(review): trust_remote_code allows code shipped with the model to run
    # on load — only pass model names you trust.
    trust_remote_code=True,
    device=device,
)

In [None]:
#define function to compute embeddings
def compute_embedding_columns(
    df: pd.DataFrame,
    embedding_model: SentenceTransformer,
    name_embedding_column_input: str,
    name_embedding_column_output: str = "EMBEDDING",
    batch_size: int = 128,
) -> pd.DataFrame:
    """
    Create an embedding column by computing embeddings batch by batch.

    The input dataframe is left untouched; a copy carrying the new column is
    returned.

    Args:
        df (pd.DataFrame): dataframe containing the data
        embedding_model (SentenceTransformer): model used to compute the embeddings
        name_embedding_column_input (str): column used as input text
        name_embedding_column_output (str): column to store embeddings
        batch_size (int): batch size for embedding computation

    Returns:
        pd.DataFrame: copy of ``df`` with the new embedding column, one list of
            numbers per row

    Raises:
        ValueError: if ``batch_size`` is not a positive integer
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    texts = df[name_embedding_column_input].tolist()
    all_embeddings = []

    # Use tqdm to show progress
    for start_idx in tqdm(range(0, len(texts), batch_size), desc="Computing embeddings"):
        batch_texts = texts[start_idx : start_idx + batch_size]

        batch_embeddings = embedding_model.encode(
            batch_texts,
            batch_size=batch_size,
            show_progress_bar=False,  # tqdm will show progress instead
            normalize_embeddings=True,
            convert_to_numpy=True,
        )

        # Store plain Python lists (one per text) for Pandas/Snowflake.
        all_embeddings.extend(emb.tolist() for emb in batch_embeddings)

    # Work on a copy so the caller's dataframe is not modified as a side effect.
    result = df.copy()
    result[name_embedding_column_output] = all_embeddings

    return result

In [None]:
# Embed TEXT_TO_EMBED and store the vectors in the EMBEDDING column.
cleaned_df_with_embedding = compute_embedding_columns(
    df=cleaned_df,
    embedding_model=model,
    name_embedding_column_input="TEXT_TO_EMBED",
    name_embedding_column_output="EMBEDDING",
)

**join with initial_df**

In [None]:
# Keep the id together with the embedding outputs.
embedding_df = cleaned_df_with_embedding[
    [id_column, "TEXT_TO_EMBED", "EMBEDDING"]
].copy()

# Attach the embedding outputs to the full source rows.
# cleaned_df_with_embedding is derived from initial_df through copies and column
# assignments only, so both frames share the same row index. Aligning on that
# index yields exactly one output row per source row; merging on id_column
# would multiply rows whenever an id value occurs more than once.
final_df = initial_df.join(embedding_df[["TEXT_TO_EMBED", "EMBEDDING"]])

In [None]:
# Preview the first rows of the joined result.
final_df.head()

In [None]:
# Show the pandas dtype of every column of the joined result.
final_df.dtypes

In [None]:
# Create a copy of the dataframe so the conversions applied to
# df_for_snowflake leave final_df unchanged.
df_for_snowflake = final_df.copy()

# Function to check if a value looks like an array
def is_array_like(val):
    """
    Tell whether a cell value is an array or a string that looks like one.

    Args:
        val: any cell value.

    Returns:
        bool: True for a list or numpy array, and for a string whose stripped
            form starts with "[" and ends with "]". False for everything else,
            including None and NaN.
    """
    if isinstance(val, (list, np.ndarray)):
        return True

    # Only strings can still qualify, so None/NaN and other scalars need no
    # separate probing.
    if not isinstance(val, str):
        return False

    val_stripped = val.strip()
    return val_stripped.startswith('[') and val_stripped.endswith(']')

# Function to convert to proper array format
def convert_to_array(val):
    """
    Convert a cell value to a Python list, or None when it is not an array.

    Args:
        val: any cell value.

    Returns:
        list | None: a numpy array converted with ``tolist()`` (so its elements
            are native Python values rather than numpy scalars), a list
            returned as is, or the parsed value of a string containing a JSON
            array. None for every other input, including None, NaN, strings
            that are not valid JSON and JSON that is not an array.
    """
    if isinstance(val, np.ndarray):
        return val.tolist()
    if isinstance(val, list):
        return val

    # Anything that is neither an array nor a string (None, NaN, numbers, ...)
    # has no array representation.
    if not isinstance(val, str):
        return None

    try:
        parsed = json.loads(val)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError: the string is not valid JSON.
        return None

    return parsed if isinstance(parsed, list) else None

# Get embedding dimension from the first non-null EMBEDDING value.
embedding_dim = None
if 'EMBEDDING' in df_for_snowflake.columns:
    non_null_embeddings = df_for_snowflake['EMBEDDING'].dropna()
    if not non_null_embeddings.empty:
        first_embedding = non_null_embeddings.iloc[0]
        try:
            if isinstance(first_embedding, (list, np.ndarray)):
                embedding_dim = len(first_embedding)
            elif isinstance(first_embedding, str):
                # Embeddings stored as text are expected to be JSON arrays.
                embedding_dim = len(json.loads(first_embedding))
        except (ValueError, TypeError):
            # Invalid JSON, or a value without a length: leave the dimension unknown.
            embedding_dim = None

if embedding_dim:
    print(f"Detected EMBEDDING dimension: {embedding_dim}")

# Detect array columns: a column is treated as ARRAY when more than half of its
# first non-null values look like arrays.
ARRAY_SAMPLE_SIZE = 10

array_columns = []
for col in df_for_snowflake.columns:
    if col == 'EMBEDDING':  # Skip EMBEDDING, handle it separately
        continue

    # Judge on non-null values only; counting nulls as samples would hide
    # array columns whose first rows are mostly empty.
    sample_values = df_for_snowflake[col].dropna().head(ARRAY_SAMPLE_SIZE)
    if sample_values.empty:
        continue

    array_count = sum(1 for val in sample_values if is_array_like(val))

    if array_count > len(sample_values) / 2:
        print(f"Detected '{col}' as ARRAY type")
        array_columns.append(col)
        df_for_snowflake[col] = df_for_snowflake[col].apply(convert_to_array)

# Normalise every EMBEDDING value to a Python list (or None) before writing.
has_embedding_column = 'EMBEDDING' in df_for_snowflake.columns
if has_embedding_column:
    normalized_embeddings = df_for_snowflake['EMBEDDING'].apply(convert_to_array)
    df_for_snowflake['EMBEDDING'] = normalized_embeddings

# Build the column list of the CREATE TABLE statement from the dataframe.
# Each entry is "<UPPERCASED NAME> <SNOWFLAKE TYPE>".
column_definitions = []
for col in df_for_snowflake.columns:
    col_upper = col.upper()
    col_dtype = df_for_snowflake[col].dtype

    if col == 'EMBEDDING' and embedding_dim:
        col_type = f"VECTOR(FLOAT, {embedding_dim})"
    elif col in array_columns:
        col_type = "ARRAY"
    elif pd.api.types.is_bool_dtype(col_dtype):
        # Booleans previously fell through to VARCHAR.
        col_type = "BOOLEAN"
    elif pd.api.types.is_integer_dtype(col_dtype):
        # Covers signed, unsigned and nullable (e.g. Int64) integer dtypes.
        col_type = "NUMBER(38,0)"
    elif pd.api.types.is_float_dtype(col_dtype):
        # Covers float16/32/64 and nullable float dtypes.
        col_type = "FLOAT"
    elif 'date' in str(col_dtype).lower():
        # NOTE(review): this also matches datetime64 dtypes, so a time-of-day
        # component would be dropped by DATE — confirm that is intended.
        col_type = "DATE"
    else:
        # Object (text) columns and anything unrecognised.
        col_type = "VARCHAR(16777216)"

    column_definitions.append(f"{col_upper} {col_type}")

**Write the output table, casting `EMBEDDING` to the VECTOR type**

In [None]:
# The INSERT below casts EMBEDDING to VECTOR(FLOAT, <dim>), so the dimension must
# be known. Check before dropping anything so that a run that cannot succeed
# leaves an existing output table untouched.
if not embedding_dim:
    raise ValueError(
        "Could not determine the EMBEDDING dimension; "
        f"refusing to recreate {output_table}."
    )

# Drop table if exists and create with proper schema
session.sql(f"DROP TABLE IF EXISTS {output_table}").collect()

create_table_sql = f"""
CREATE TABLE {output_table} (
    {', '.join(column_definitions)}
)
"""

print("\nCreating table with schema:")
print(create_table_sql)

session.sql(create_table_sql).collect()

# Stage the full dataframe in a scratch table, then copy it into the output
# table while casting EMBEDDING to the VECTOR type.
temp_table = f"{output_table}_TEMP"
snowpark_df = session.create_dataframe(df_for_snowflake)

try:
    # "overwrite" replaces any leftover scratch table from an earlier run.
    snowpark_df.write.mode("overwrite").save_as_table(temp_table)

    # Copy data with proper VECTOR casting. The SELECT lists the columns in
    # dataframe order, which is also the order used for the CREATE TABLE.
    all_cols = [col.upper() for col in df_for_snowflake.columns]
    select_cols = [
        f"{col}::VECTOR(FLOAT, {embedding_dim})" if col == 'EMBEDDING' else col
        for col in all_cols
    ]

    insert_sql = f"""
INSERT INTO {output_table}
SELECT {', '.join(select_cols)}
FROM {temp_table}
"""

    print(f"\nInserting data with VECTOR casting...")
    session.sql(insert_sql).collect()
finally:
    # Clean up temp table, also when the write or the insert failed.
    session.sql(f"DROP TABLE IF EXISTS {temp_table}").collect()

print(f"\nTable {output_table} created successfully!")

# Print the column layout of the new table as a sanity check.
print("\nTable Schema:")
describe_sql = f"DESCRIBE TABLE {output_table}"
session.sql(describe_sql).show()

# Report how many rows ended up in the table.
count_rows = session.sql(f"SELECT COUNT(*) as row_count FROM {output_table}").collect()
print(f"\nRows inserted: {count_rows[0]['ROW_COUNT']}")