In [55]:
import requests
import snowflake.connector

HOST = "FRHSAXD-CFB18747.snowflakecomputing.com"
API_ENDPOINT = "/api/v2/cortex/agent:run"
API_TIMEOUT = 50000  # in milliseconds

CORTEX_SEARCH_SERVICES = "DASH_DB.DASH_SCHEMA.VEHICLES_INFO"
SEMANTIC_MODELS_SUPPLY_CHAIN = "@DASH_DB.DASH_SCHEMA.DASH_SEMANTIC_MODELS/supply_chain_semantic_model.yaml"
SEMANTIC_MODELS_SUPPORT_TICKETS = "@DASH_DB.DASH_SCHEMA.DASH_SEMANTIC_MODELS/support_tickets_semantic_model.yaml"

In [None]:
# Authenticate with private key
snow_conn = snowflake.connector.connect(
    account="FRHSAXD-CFB18747",
    user="ANGELCORTEX2025",
    private_key_file="/home/angel/sf_keys/rsa_key.p8",
    port=443,
    warehouse="COMPUTE_WH",
    role="ACCOUNTADMIN",
)

with snow_conn.cursor() as cur:
    df = cur.execute("SELECT CURRENT_ACCOUNT(), CURRENT_ROLE(), CURRENT_USER()").fetch_pandas_all()

df

Unnamed: 0,CURRENT_ACCOUNT(),CURRENT_ROLE(),CURRENT_USER()
0,XIB99769,ACCOUNTADMIN,ANGELCORTEX2025


In [45]:
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import PublicFormat
from cryptography.hazmat.backends import default_backend

import base64
from getpass import getpass
import hashlib
# If you generated an encrypted private key, implement this method to return
# the passphrase for decrypting your private key. As an example, this function
# prompts the user for the passphrase.
def get_private_key_passphrase():
    return getpass('Passphrase for private key: ')

private_key = "/home/angel/sf_keys/rsa_key.p8"

# Open the private key file.
# Replace <private_key_file_path> with the path to your private key file (e.g. /x/y/z/rsa_key.p8).
with open(private_key, 'rb') as pem_in:
    pemlines = pem_in.read()
    try:
        # Try to access the private key without a passphrase.
        private_key = load_pem_private_key(pemlines, None, default_backend())
    except TypeError:
        # If that fails, provide the passphrase returned from get_private_key_passphrase().
        private_key = load_pem_private_key(pemlines, get_private_key_passphrase().encode(), default_backend())

# Get the raw bytes of the public key.
public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)

# Get the sha256 hash of the raw bytes.
sha256hash = hashlib.sha256()
sha256hash.update(public_key_raw)

# Base64-encode the value and prepend the prefix 'SHA256:'.
public_key_fp = 'SHA256:' + base64.b64encode(sha256hash.digest()).decode('utf-8')
public_key_fp

'SHA256:7KzO++ERgIBiUQc62xv1guqtW/cI39wXjIwHJyCjIkQ='

In [None]:
from datetime import timedelta, timezone, datetime

# This example relies on the PyJWT module (https://pypi.org/project/PyJWT/).
import jwt

# Construct the fully qualified name of the user in uppercase.
# - Replace <account_identifier> with your account identifier.
#   (See https://docs.snowflake.com/en/user-guide/admin-account-identifier.html .)
# - Replace <user_name> with your Snowflake user name.
account = "FRHSAXD-CFB18747"

# Use uppercase for the account identifier and user name.
account = account.upper()
user = "ANGELCORTEX2025".upper()
qualified_username = account + "." + user

# Get the current time in order to specify the time when the JWT was issued and the expiration time of the JWT.
now = datetime.now(timezone.utc)

# Specify the length of time during which the JWT will be valid. You can specify at most 1 hour.
lifetime = timedelta(minutes=59)

# Create the payload for the token.
payload = {

    # Set the issuer to the fully qualified username concatenated with the public key fingerprint (calculated in the  previous step).
    "iss": qualified_username + '.' + public_key_fp,

    # Set the subject to the fully qualified username.
    "sub": qualified_username,

    # Set the issue time to now.
    "iat": now,

    # Set the expiration time, based on the lifetime specified for this object.
    "exp": now + lifetime
}

# Generate the JWT. private_key is the private key that you read from the private key file in the previous step when you generated the public key fingerprint.
encoding_algorithm="RS256"
token = jwt.encode(payload, key=private_key, algorithm=encoding_algorithm)

# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string, rather than a string.
# If the token is a byte string, convert it to a string.
if isinstance(token, bytes):
  token = token.decode('utf-8')
decoded_token = jwt.decode(token, key=private_key.public_key(), algorithms=[encoding_algorithm])
print("Generated a JWT with the following payload:\n{}".format(decoded_token))

Generated a JWT with the following payload:
{'iss': 'FRHSAXD-CFB18747.ANGELCORTEX2025.SHA256:7KzO++ERgIBiUQc62xv1guqtW/cI39wXjIwHJyCjIkQ=', 'sub': 'FRHSAXD-CFB18747.ANGELCORTEX2025', 'iat': 1742953288, 'exp': 1742956828}


In [56]:
payload = {
    "model": "llama3.1-70b",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you show me a breakdown of customer support tickets by service type cellular vs business internet?"
                }
            ]
        }
    ],
   "tools": [
            { "tool_spec": { "type": "cortex_search", "name": "vehicles_info_search" } },
            { "tool_spec": { "type": "cortex_analyst_text_to_sql", "name": "support" } },
            { "tool_spec": { "type": "cortex_analyst_text_to_sql", "name": "supply_chain" } }
        ],
        "tool_resources": {
            "supply_chain": {"semantic_model_file": SEMANTIC_MODELS_SUPPLY_CHAIN},
            "support": {"semantic_model_file": SEMANTIC_MODELS_SUPPORT_TICKETS},
            "vehicles_info_search": {
                "name": CORTEX_SEARCH_SERVICES,
                "max_results": 10,
                "title_column": "title",
                "id_column": "relative_path"
            }
        }
}
    
resp = requests.post(
    url=f"https://{HOST}/{API_ENDPOINT}",
    json=payload,
    headers={
        'X-Snowflake-Authorization-Token-Type': 'KEYPAIR_JWT',
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f'Bearer {token}', 
    },
)

In [61]:
print(resp.content.decode("utf-8"))

event: message.delta
data: {"id":"msg_001","object":"message.delta","delta":{"content":[{"index":0,"type":"tool_use","tool_use":{"tool_use_id":"toolu_5bdcd3c0","name":"support","input":{"messages":["role:USER content:{text:{text:\"Can you show me a breakdown of customer support tickets by service type cellular vs business internet?\"}}"],"model":"snowflake-hosted-semantic","experimental":""}}},{"index":0,"type":"tool_results","tool_results":{"tool_use_id":"toolu_5bdcd3c0","content":[{"type":"json","json":{"text":"This is our interpretation of your question:\n\nShow me the count of support tickets for each service type, specifically comparing Cellular and Business Internet services","suggestions":[],"sql":"WITH __support_tickets AS (\n  SELECT\n    ticket_id,\n    service_type\n  FROM dash_db.dash_schema.support_tickets\n)\nSELECT\n  service_type,\n  COUNT(DISTINCT ticket_id) AS ticket_count\nFROM __support_tickets\nWHERE\n  service_type IN (\u0027Cellular\u0027, \u0027Business Internet

In [50]:
import json

def parse_response_to_events(response):
    """Parse raw SSE response into a list of events"""
    events = []
    
    # Convert bytes to string if necessary
    if isinstance(response, bytes):
        response = response.decode("utf-8")
    
    # Split the response into lines
    lines = response.splitlines()
    
    # Process each line to extract events
    for line in lines:
        if line.startswith("event:"):
            # Extract the event type
            event_type = line[len("event: "):].strip()
        elif line.startswith("data:"):
            # Extract the data and parse it as JSON
            try:
                data = json.loads(line[len("data: "):].strip())
                events.append({"event": event_type, "data": data})
            except json.JSONDecodeError:
                print(f"Invalid JSON in line: {line}")
    
    return events

def process_sse_response(response):
    """Process SSE response"""
    text = ""
    sql = ""
    citations = []
    
    if not response:
        return text, sql, citations
    if isinstance(response, str):
        return text, sql, citations
    try:
        for event in response:
            if event.get('event') == "message.delta":
                data = event.get('data', {})
                delta = data.get('delta', {})
                
                for content_item in delta.get('content', []):
                    content_type = content_item.get('type')
                    if content_type == "tool_results":
                        tool_results = content_item.get('tool_results', {})
                        if 'content' in tool_results:
                            for result in tool_results['content']:
                                if result.get('type') == 'json':
                                    text += result.get('json', {}).get('text', '')
                                    search_results = result.get('json', {}).get('searchResults', [])
                                    for search_result in search_results:
                                        citations.append({'source_id':search_result.get('source_id',''), 'doc_id':search_result.get('doc_id', '')})
                                    sql = result.get('json', {}).get('sql', '')
                    if content_type == 'text':
                        text += content_item.get('text', '')
                            
    except json.JSONDecodeError as e:
        print(f"Error processing events: {str(e)}")
                
    except Exception as e:
        print(f"Error processing events: {str(e)}")
        
    return text, sql, citations

In [52]:
# Example usage
raw_response = resp.content  # Assuming this is the raw SSE response
events = parse_response_to_events(raw_response)

# Pass the parsed events to the process_sse_response function
text, sql, citations = process_sse_response(events)

# Print the results
print("Text:", text)
print("SQL:", sql)
print("Citations:", citations)

Invalid JSON in line: data: [DONE]
Text: I don't know the answer to that question.
SQL: 
Citations: [{'source_id': 1.0, 'doc_id': 'CONV005'}, {'source_id': 2.0, 'doc_id': 'CONV001'}, {'source_id': 3.0, 'doc_id': 'CONV004'}, {'source_id': 4.0, 'doc_id': 'CONV007'}, {'source_id': 5.0, 'doc_id': 'CONV002'}, {'source_id': 6.0, 'doc_id': 'CONV010'}, {'source_id': 7.0, 'doc_id': 'CONV009'}, {'source_id': 8.0, 'doc_id': 'CONV008'}, {'source_id': 9.0, 'doc_id': 'CONV006'}, {'source_id': 10.0, 'doc_id': 'CONV003'}]


In [59]:
events

[{'event': 'message.delta',
  'data': {'id': 'msg_001',
   'object': 'message.delta',
   'delta': {'content': [{'index': 0,
      'type': 'tool_use',
      'tool_use': {'tool_use_id': 'toolu_a719d480',
       'name': 'search1',
       'input': {'scoringConfig': '<nil>',
        'experimentalJson': '',
        'query': 'Can you show me a breakdown of customer support tickets by service type cellular vs business internet?',
        'columns': ['conversation_id'],
        'filters': '',
        'limit': 10.0,
        'requestId': '252e2ed3-129c-4ca1-b2bc-4791ebcc4d83'}}},
     {'index': 0,
      'type': 'tool_results',
      'tool_results': {'tool_use_id': 'toolu_a719d480',
       'content': [{'type': 'json',
         'json': {'searchResults': [{'doc_title': '',
            'text': "In-depth demo session with DataDriven Co's Analytics team and Business Intelligence managers. Showcase focused on advanced analytics capabilities, custom dashboard creation, and real-time data processing featu