# Model Inference using SPCS

Documentation: https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container

#### Upgrade `snowflake-ml-python` package

In [None]:
! pip install snowflake-ml-python --upgrade -q

In [None]:
# Import python packages
import json

import pandas as pd
import requests
import transformers

import snowflake.connector
from snowflake.ml import version
from snowflake.ml.registry import registry as registry_module
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

from snowflake.snowpark.context import get_active_session

# Reuse the Snowpark session that is already active in this notebook's runtime.
session = get_active_session()
# Print the installed snowflake-ml-python version (useful after the upgrade cell).
print("Snowflake ML version: ", version.VERSION)

#### Create a transformer pipeline model

In [None]:
# Build a Hugging Face text-generation pipeline for Llama 3.1 8B Instruct.
llama_3_model = transformers.pipeline(
    task="text-generation",
    model="meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",
    # TODO: Add your token here
    token="hf_...",
)

# Echo the pipeline object as the cell output.
llama_3_model

In [None]:
# Open the model registry through the current Snowpark session.
# NOTE(review): no database/schema is passed, so the registry presumably uses
# the session's current ones — confirm this is the intended location.
registry = registry_module.Registry(session=session)
# Echo the registry object as the cell output.
registry

In [None]:
# Log the pipeline to the registry as version "V1" of the model "llama_3",
# targeting Snowpark Container Services as the deployment platform.
mv = registry.log_model(
    model=llama_3_model,
    model_name="llama_3",
    version_name="V1",
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
)
# Echo the returned model version as the cell output.
mv

#### Create a service from the logged model

In [None]:
# Deploy the logged model version as a service named "llama_3_service".
# Replace the <...> placeholders below before running this cell.
mv.create_service(
    service_name="llama_3_service",
    # TODO: Add your image repo here
    image_repo="<your-image-repo>",
    # TODO: Add your compute pool here
    service_compute_pool="<your-compute-pool>",
    # TODO: Modify number of GPUs here
    gpu_requests="1",
    # NOTE(review): presumably what exposes the ingress endpoint used by the
    # REST call later in this notebook — confirm against the Snowflake docs.
    ingress_enabled=True,
)

In [None]:
# List all services in a compute pool
# TODO: Replace <your-compute-pool> in the query with your compute pool name
session.sql("SHOW SERVICES IN COMPUTE POOL <your-compute-pool>").collect()

In [None]:
# List all endpoints in a service
# The service ingress URL needed for the REST call later in this notebook
# can be read from this query's output.
session.sql("SHOW ENDPOINTS IN SERVICE LLAMA_3_SERVICE").collect()

In [None]:
# Look up version "V1" of the "llama_3" model from the registry by name, so the
# cells below have `mv` even when the log_model cell was not run in this session.
mv = registry.get_model("llama_3").version("V1")
# Echo the model version as the cell output.
mv

#### Call the service function of the model

In [None]:
# One chat conversation (system prompt followed by a user question), wrapped in
# an outer list so that `x` holds exactly one conversation.
x = [[
    dict(role="system", content="You are an helpful assistant."),
    dict(role="user", content="What is the capital of France?"),
]]

# Single-row frame: the "inputs" column holds the whole conversation
# (a list of message dicts) in one cell.
x_df = pd.DataFrame({"inputs": x})
x_df

In [None]:
# Run inference on the input frame by calling the model's "__call__" function
# through the deployed LLAMA_3_SERVICE service.
output_df = mv.run(
    X=x_df,
    function_name="__call__",
    service_name="LLAMA_3_SERVICE",
)
# Echo the result as the cell output.
output_df

In [None]:
# Show the first cell (row 0, column 0) of the inference output.
# A single positional lookup is used instead of chained `[0][0]`, which would
# index the row Series by integer key and depend on positional fallback.
output_df.iloc[0, 0]

#### Invoke inference using the REST API

In [None]:
def initiate_snowflake_connection():
    """Open a Snowflake connector connection from the default login options.

    The login options come from ``SnowflakeLoginOptions()``; a
    ``session_parameters`` entry is added that sets
    ``PYTHON_CONNECTOR_QUERY_RESULT_FORMAT`` to ``"json"``.

    Returns:
        The connection object returned by ``snowflake.connector.connect``.
    """
    login_options = {
        **SnowflakeLoginOptions(),
        "session_parameters": {"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": "json"},
    }
    return snowflake.connector.connect(**login_options)


def get_headers(snowflake_conn):
    """Build the HTTP Authorization header for calling the service endpoint.

    Requests a token through the connection's REST client (request type
    ``"ISSUE"``) and embeds the returned session token in the header value.

    Args:
        snowflake_conn: An open Snowflake connector connection.

    Returns:
        A dict with a single ``"Authorization"`` entry.
    """
    issued = snowflake_conn._rest._token_request("ISSUE")
    session_token = issued["data"]["sessionToken"]
    return {"Authorization": f'Snowflake Token="{session_token}"'}


# Open the connection once; it is used below to obtain the REST auth header.
snowflake_conn = initiate_snowflake_connection()

In [None]:
# TODO: change the url to the service ingress url
# this can be found in the "SHOW ENDPOINTS IN SERVICE LLAMA_3_SERVICE" sql query output above
# Use https: the request carries a session token in the Authorization header,
# so it must not travel over plaintext HTTP.
# NOTE(review): the "/--call--" path presumably maps to the model's "__call__"
# function — confirm against the SHOW ENDPOINTS output / Snowflake docs.
url = "https://<ingress-url>/--call--"

# NOTE(review): Snowflake's REST examples use row-indexed payloads of the form
# {"data": [[<row index>, <column values>...]]} — confirm that `x` as sent here
# matches the shape the endpoint expects.
response = requests.post(
    url,
    json={"data": x},
    headers=get_headers(snowflake_conn),
    timeout=15,
)

# Show the raw response body as the cell output.
response.text