# Sentence Transformer Batch Inference Quickstart

This notebook demonstrates how to perform batch inference using a Sentence Transformer model in Snowflake. It walks through the complete workflow from model deployment to generating embeddings for text data.

## Overview

Sentence Transformers are models that convert text into dense vector embeddings, useful for semantic search, clustering, and similarity comparisons. This quickstart shows how to:

1. Set up Snowflake resources (database, schema, stage, compute pool)
2. Load and log a pre-trained Sentence Transformer model
3. Create an input dataset with text sentences
4. Run batch inference to generate embeddings
5. Inspect the output and load it into a table
6. Clean up resources
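
As a rough illustration of the "similarity comparisons" mentioned above, two embeddings are commonly compared with cosine similarity. The sketch below uses only the standard library and made-up three-dimensional vectors. Real embeddings have many more dimensions, and the function name `cosine_similarity` is chosen here for illustration rather than taken from the notebook.

```python
import math


def cosine_similarity(a, b):
    """Return the cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors, invented for illustration only (not model output).
query = [0.2, 0.1, 0.7]
candidate = [0.25, 0.05, 0.65]
print(cosine_similarity(query, candidate))
```

Values close to 1 suggest the two texts point in a similar direction in embedding space, while values near 0 suggest they are unrelated.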

## Prerequisites

- Snowflake account with appropriate privileges
- `snowflake-ml-python>=1.26.0` (for batch inference support)
- `sentence-transformers==5.1.1`
- `numpy==1.26.4`
- A valid Snowflake connection configuration
- (Optional) A Snowflake stage and a compute pool. 
- (Optional) A Sentence Transformer model logged in Snowflake registry.
- (Optional) A test dataset.

## Running the Notebook

Run the cells in order. If you already have the resources created or a model logged, skip the corresponding cells.

# Install the Required Dependencies

In [3]:
# uncomment to install the packages
# restart the session after installing the packages

# ! pip install sentence-transformers==5.1.1 --upgrade
# ! pip install numpy==1.26.4 --upgrade

# Make Sure snowflake-ml-python Has the Right Version

In [10]:
from importlib.metadata import version
# batch inference (PuPr, i.e. public preview) requires snowflake-ml-python>=1.26.0
print(version('snowflake-ml-python'))

1.26.0
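
Printing the version leaves the comparison to the reader. A minimal sketch of an automated check is shown below, assuming a plain `major.minor.patch` scheme. The helper names `parse_release` and `meets_minimum` are hypothetical, and pre-release suffixes such as `rc1` are simply ignored.

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum version stated in the prerequisites of this quickstart.
MIN_VERSION = (1, 26, 0)


def parse_release(version_string):
    """Turn '1.26.0' into (1, 26, 0), dropping any non-numeric suffix."""
    parts = []
    for piece in version_string.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def meets_minimum(version_string, minimum=MIN_VERSION):
    """Compare numerically, so '1.9.0' is correctly older than '1.26.0'."""
    return parse_release(version_string) >= minimum


try:
    installed = version("snowflake-ml-python")
    status = "OK" if meets_minimum(installed) else "too old for batch inference"
    print(installed, status)
except PackageNotFoundError:
    print("snowflake-ml-python is not installed")
```

Comparing tuples of integers avoids the pitfall of string comparison, where `"1.9.0"` would sort after `"1.26.0"`.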


# Establish a Connection

In [11]:
from snowflake.snowpark import Session
session = Session.builder.config("connection_name", "preprod8").create()

# Create Resources

If you already have a stage and compute pool, set DB_NAME, SCHEMA_NAME, STAGE_NAME, and COMPUTE_POOL_NAME to their names and skip the resource-creation cell that follows the variable definitions.

Otherwise, leave the values as they are and run the cells to create the resources.

In [12]:
DB_NAME = "BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_DB"
SCHEMA_NAME = "PUBLIC"
STAGE_NAME = "BATCH_INFERENCE_QUICKSTART_STAGE"
COMPUTE_POOL_NAME = "BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_COMPUTE_POOL"

In [13]:
# Create database
session.sql(f"CREATE DATABASE IF NOT EXISTS {DB_NAME}").collect()

# Create schema
session.sql(f"CREATE SCHEMA IF NOT EXISTS {DB_NAME}.{SCHEMA_NAME}").collect()

# Create stage with SSE encryption
session.sql(f"""
    CREATE STAGE IF NOT EXISTS {DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}
    ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE')
""").collect()

# Create compute pool with smallest CPU tier
# use GPU_NV_S for GPU workloads
session.sql(f"""
    CREATE COMPUTE POOL IF NOT EXISTS {COMPUTE_POOL_NAME}
    MIN_NODES = 1
    MAX_NODES = 2
    INSTANCE_FAMILY = CPU_X64_XS
""").collect()

# Set the session to use the newly created database and schema
session.use_database(DB_NAME)
session.use_schema(SCHEMA_NAME)

print(f"Created database: {DB_NAME}")
print(f"Created schema: {DB_NAME}.{SCHEMA_NAME}")
print(f"Created stage: {DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}")
print(f"Created compute pool: {COMPUTE_POOL_NAME}")

Created database: BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_DB
Created schema: BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_DB.PUBLIC
Created stage: BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_DB.PUBLIC.BATCH_INFERENCE_QUICKSTART_STAGE
Created compute pool: BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_COMPUTE_POOL
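
Because the resource names are interpolated directly into SQL strings, it can be worth checking them before running the statements. The sketch below is an optional guard, not part of the original notebook. It assumes the names are unquoted Snowflake identifiers (a letter or underscore followed by letters, digits, underscores, or dollar signs, up to 255 characters), and `is_safe_identifier` is a hypothetical helper.

```python
import re

# Unquoted identifier shape: starts with a letter or underscore,
# then letters, digits, underscores, or dollar signs.
_UNQUOTED_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")


def is_safe_identifier(name, max_length=255):
    """Return True if `name` looks like a plain unquoted identifier."""
    return (
        len(name) <= max_length
        and _UNQUOTED_IDENTIFIER.fullmatch(name) is not None
    )


names = {
    "DB_NAME": "BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_DB",
    "SCHEMA_NAME": "PUBLIC",
    "STAGE_NAME": "BATCH_INFERENCE_QUICKSTART_STAGE",
    "COMPUTE_POOL_NAME": "BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_COMPUTE_POOL",
}

for label, value in names.items():
    if not is_safe_identifier(value):
        raise ValueError(f"{label}={value!r} is not a plain identifier")
print("All resource names look like plain identifiers.")
```

Names that need quoting (spaces, mixed case that must be preserved) would fail this check and need separate handling.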


### If you don't already have a model logged, load and log the SentenceTransformer model

In [14]:
from sentence_transformers import SentenceTransformer
from snowflake.ml.registry import registry

sample_input_data = [
    "This is the first sentence.",
    "Here's another sentence for testing.",
]

reg = registry.Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME)

embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

mv = reg.log_model(
    embed_model,
    model_name="sentence_transformer_minilm",
    sample_input_data=sample_input_data,
    pip_requirements=[
        "numpy==1.26.4",
        "sentence-transformers==5.1.1",
        "torch==2.9.0",
        "transformers==4.57.1",
    ],
)

Model logged successfully.: 100%|██████████| 6/6 [00:51<00:00,  8.56s/it]                          


### If you already have a model logged, please fill in MODEL_NAME and VERSION_NAME

In [None]:
from snowflake.ml.registry import registry

MODEL_NAME = "MODEL_NAME"
VERSION_NAME = "V1"

reg = registry.Registry(session=session, database_name=DB_NAME, schema_name=SCHEMA_NAME)
mv = reg.get_model(MODEL_NAME).get_version(VERSION_NAME)

# Create/Load Inference Input Dataset

In [15]:
import pandas as pd

# Define the data for the DataFrame
# The first column is the text to be embedded, the second column is the pass through ID column to map back to the input.
data = [
    ("The quick brown fox jumps over the lazy dog.", "a1b2c3d4-e5f6-7890-1234-567890abcdef"),
    ("Snowpark is a great library for data processing.", "f9e8d7c6-b5a4-3210-fedc-ba9876543210"),
    ("Python is a versatile programming language.", "1a2b3c4d-5e6f-7080-9101-112131415161")
]

# Define the column names and data types
columns = ["input_feature_0", "ID"]
schema = ["input_feature_0 VARCHAR", "ID VARCHAR(36)"]

# Create a pandas DataFrame
pandas_df = pd.DataFrame(data, columns=columns)

# Create the Snowpark DataFrame from the pandas DataFrame
X = session.create_dataframe(pandas_df, schema=schema)

# show the first row of the input dataset to verify the data is loaded correctly
X.show(1)


---------------------------------------------------------------------------------------
|"input_feature_0"                             |"ID"                                  |
---------------------------------------------------------------------------------------
|The quick brown fox jumps over the lazy dog.  |a1b2c3d4-e5f6-7890-1234-567890abcdef  |
---------------------------------------------------------------------------------------
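
The IDs in the cell above are hard-coded. For a larger dataset, one option is to generate the pass-through IDs programmatically. The sketch below uses the standard `uuid` module, whose string form is 36 characters long and so fits the `VARCHAR(36)` column. The helper name `with_row_ids` is hypothetical.

```python
import uuid


def with_row_ids(sentences):
    """Pair each sentence with a freshly generated UUID string."""
    return [(text, str(uuid.uuid4())) for text in sentences]


sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "Snowpark is a great library for data processing.",
    "Python is a versatile programming language.",
]

rows = with_row_ids(sentences)
for text, row_id in rows:
    print(row_id, text)
```

The resulting list of tuples has the same shape as `data` in the cell above, so it could be passed to `pd.DataFrame(rows, columns=columns)` in the same way.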



# Run the Batch Inference Job

In [None]:
from snowflake.ml.model import JobSpec, OutputSpec, SaveMode

output_stage_location = f"@{DB_NAME}.{SCHEMA_NAME}.{STAGE_NAME}/output/"

job = mv.run_batch(
    X=X,
    compute_pool=COMPUTE_POOL_NAME,
    output_spec=OutputSpec(stage_location=output_stage_location, mode=SaveMode.OVERWRITE),
    job_spec=JobSpec(function_name="encode"),
)

job.wait() # Wait for the job to complete

# Inspect the Inference Output

In [17]:
session.sql(f'LS {output_stage_location}').show()
# The parquet files are the result of the inference job.
# The "_SUCCESS" file is a marker file that indicates the job has completed successfully.

---------------------------------------------------------------------------------------------------------------------------------
|"name"                                              |"size"  |"md5"                             |"last_modified"               |
---------------------------------------------------------------------------------------------------------------------------------
|batch_inference_quickstart_stage/output/3_63012...  |11627   |27033a315883bacf5e508df8af3a3daa  |Mon, 2 Feb 2026 23:05:07 GMT  |
|batch_inference_quickstart_stage/output/_SUCCESS    |0       |d41d8cd98f00b204e9800998ecf8427e  |Mon, 2 Feb 2026 23:05:07 GMT  |
---------------------------------------------------------------------------------------------------------------------------------
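
A listing like the one above can also be checked in code before reading the results. The sketch below works on a plain list of file names and uses the same `.*\.parquet` pattern as the read step. The file names are hypothetical stand-ins, and `summarize_listing` is an illustrative helper, not part of the Snowflake API.

```python
import re

# Same pattern the notebook passes to the reader's "pattern" option.
PARQUET_PATTERN = re.compile(r".*\.parquet")


def summarize_listing(file_names):
    """Return (has_success_marker, parquet_files) for a stage listing."""
    has_marker = any(
        name.rsplit("/", 1)[-1] == "_SUCCESS" for name in file_names
    )
    parquet_files = [
        name for name in file_names if PARQUET_PATTERN.fullmatch(name)
    ]
    return has_marker, parquet_files


# Hypothetical file names, shaped like the "name" column of the listing.
listing = [
    "batch_inference_quickstart_stage/output/part-0.parquet",
    "batch_inference_quickstart_stage/output/_SUCCESS",
]

done, parquet_files = summarize_listing(listing)
print("job marked successful:", done)
print("parquet files:", parquet_files)
```

In the notebook, the names would presumably come from the `name` column of the `LS` result. That wiring is left out here so the example runs without a Snowflake session.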



In [18]:
session.read.option("pattern", ".*\\.parquet").parquet(output_stage_location).show(1, max_width=200)

-------------------------------------------------------------------------------------------------------------------
|"input_feature_0"                             |"ID"                                  |"output_feature_0"         |
-------------------------------------------------------------------------------------------------------------------
|The quick brown fox jumps over the lazy dog.  |a1b2c3d4-e5f6-7890-1234-567890abcdef  |[                          |
|                                              |                                      |  4.393358901143074e-02,   |
|                                              |                                      |  5.893439799547195e-02,   |
|                                              |                                      |  4.817844927310944e-02,   |
|                                              |                                      |  7.754809409379959e-02,   |
|                                              |                        

## Copy the Output Stage Files into a Table

In [19]:
output_table = "batch_inference_output_table"
session.read.option("pattern", ".*\\.parquet").parquet(output_stage_location).write.mode("overwrite").saveAsTable(output_table)
# show the first row of the output table to verify the data is loaded correctly
session.table(f'{DB_NAME}.{SCHEMA_NAME}.{output_table}').show(1) 

-------------------------------------------------------------------------------------------------------------------
|"input_feature_0"                             |"ID"                                  |"output_feature_0"         |
-------------------------------------------------------------------------------------------------------------------
|The quick brown fox jumps over the lazy dog.  |a1b2c3d4-e5f6-7890-1234-567890abcdef  |[                          |
|                                              |                                      |  4.393358901143074e-02,   |
|                                              |                                      |  5.893439799547195e-02,   |
|                                              |                                      |  4.817844927310944e-02,   |
|                                              |                                      |  7.754809409379959e-02,   |
|                                              |                        
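
Since the `ID` column passes through unchanged, each embedding can be mapped back to its input. The sketch below shows one way to build that lookup and to spot inputs with no result. It works on plain dictionaries with made-up, shortened vectors, and it assumes an array value may arrive either as a list or as a JSON string once collected on the client. The helper names are hypothetical.

```python
import json


def embeddings_by_id(rows, id_key="ID", vector_key="output_feature_0"):
    """Build {id: embedding} from output rows."""
    lookup = {}
    for row in rows:
        vector = row[vector_key]
        # Assumption: the array may be serialized as a JSON string.
        if isinstance(vector, str):
            vector = json.loads(vector)
        lookup[row[id_key]] = [float(v) for v in vector]
    return lookup


def missing_ids(input_ids, lookup):
    """Return the input IDs that have no embedding in the output."""
    return [row_id for row_id in input_ids if row_id not in lookup]


# Made-up rows with shortened vectors, for illustration only.
output_rows = [
    {"ID": "a1b2c3d4-e5f6-7890-1234-567890abcdef", "output_feature_0": "[0.04, 0.05]"},
    {"ID": "f9e8d7c6-b5a4-3210-fedc-ba9876543210", "output_feature_0": [0.01, 0.02]},
]
input_ids = [
    "a1b2c3d4-e5f6-7890-1234-567890abcdef",
    "f9e8d7c6-b5a4-3210-fedc-ba9876543210",
    "1a2b3c4d-5e6f-7080-9101-112131415161",
]

lookup = embeddings_by_id(output_rows)
print("missing:", missing_ids(input_ids, lookup))
```

An empty "missing" list would indicate that every input row produced an embedding.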

# Clean Up

In [None]:
# uncomment to clean up the database and compute pool
# session.sql(f'DROP DATABASE IF EXISTS {DB_NAME}').collect()
# session.sql(f'DROP COMPUTE POOL IF EXISTS {COMPUTE_POOL_NAME}').collect()

[Row(status='BATCH_INFERENCE_QUICKSTART_SENTENCE_TRANSFORMER_COMPUTE_POOL successfully dropped.')]