## A first approach to downloading Transformer models
  * may download the model multiple times
  * can possibly report back via the internet during inference, as an outbound connection is still permitted (per the External Access network rule)

In [1]:
from snowflake.snowpark.session import Session
import snowflake.connector
import snowflake.snowpark.functions as F
import os
import json

In [2]:
# Read the Snowflake credentials and session context from a JSON file kept
# outside the notebook, so that no secret is hardcoded in a cell.
with open('../config.json') as config_file:
    config = json.load(config_file)

# USERNAME: run `desc user <username>;` in the UI, or look under
#           Snowflake APP UI --> "HOME icon" | Profile | Username.
# SF_ACCOUNT: everything between https:// and snowflakecomputing.com.
USERNAME = config['user']
PASSWORD = config['password']
SF_ACCOUNT = config['account']
SF_ROLE = config['role']
SF_WH = config['warehouse']
SF_DB = config['database']
SF_SCHEMA = config['schema']

# A single parameter set that feeds both clients created below.
connection_parameters = dict(
    account=SF_ACCOUNT,
    user=USERNAME,
    password=PASSWORD,
    role=SF_ROLE,
    warehouse=SF_WH,
    database=SF_DB,
    schema=SF_SCHEMA,
)

# Snowpark session (DataFrame API and session.sql).
session = Session.builder.configs(connection_parameters).create()

# Separate plain connector connection and cursor, opened with the same settings.
con = snowflake.connector.connect(**connection_parameters)
cur = con.cursor()

In [3]:
# Show which role and Snowflake version this session is running with.
# (current_account(), current_user(), current_database(), current_schema()
# and current_warehouse() can be added to the select list the same way.)
role_version_sql = 'select  current_role(),  current_version()'
session.sql(role_version_sql).show()

# Look up the configured warehouse, then read its "type" column back from the
# result of that SHOW command. The two statements must run back to back:
# last_query_id() refers to the statement executed immediately before it.
show_warehouse_sql = f"show warehouses like '{SF_WH}'"
session.sql(show_warehouse_sql).collect()

warehouse_type_sql = 'SELECT "type" as warehouseType FROM table(result_scan(last_query_id()))'
session.sql(warehouse_type_sql).show()

------------------------------------------
|"CURRENT_ROLE()"  |"CURRENT_VERSION()"  |
------------------------------------------
|ACCOUNTADMIN      |7.31.0               |
------------------------------------------

----------------------
|"WAREHOUSETYPE"     |
----------------------
|SNOWPARK-OPTIMIZED  |
----------------------



## Start Inference via External Access with Local LLM (within Snowpark)

In [4]:
# Egress network rule listing the Hugging Face hosts that may be reached
# from inside Snowflake.
network_rule_ddl = '''CREATE or REPLACE NETWORK RULE external_access_llm_rule
MODE = EGRESS
TYPE = HOST_PORT
VALUE_LIST =( 'huggingface.co','cdn-lfs.huggingface.co', 'api-inference.huggingface.co')
'''

cur.execute(network_rule_ddl)

<snowflake.connector.cursor.SnowflakeCursor at 0x7f99b8adbfd0>

In [5]:
# External access integration that enables the network rule
# external_access_llm_rule; the UDF references this integration by name.
integration_ddl = '''CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION external_access_llm_integration
ALLOWED_NETWORK_RULES = (external_access_llm_rule)
ENABLED = true;
'''

cur.execute(integration_ddl)

<snowflake.connector.cursor.SnowflakeCursor at 0x7f99b8adbfd0>

In [6]:
# Create the Python UDF that answers a prompt with google/flan-t5-base.
# The model files are fetched from Hugging Face through the external access
# integration the first time a Python process needs them.
sql_statement =  '''CREATE OR REPLACE FUNCTION external_access_base_llm_infer(question VARCHAR)
RETURNS VARIANT
LANGUAGE PYTHON
runtime_version='3.8'
PACKAGES = ('transformers', 'pytorch', 'sentencepiece')
EXTERNAL_ACCESS_INTEGRATIONS = (external_access_llm_integration)
handler='base_llm'
as
$$
import functools
import os

# Set before transformers is imported so the downloaded files are cached under /tmp/.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/'

from transformers import T5Tokenizer, T5ForConditionalGeneration

modelID = "google/flan-t5-base" # 945 MB


@functools.lru_cache(maxsize=None)
def _load_model():
  """Load the tokenizer and model once per Python process and reuse them.

  Without the cache, every row handled by the UDF would re-initialize the
  tokenizer and model. Separate processes still each perform their own load.
  """
  tokenizer = T5Tokenizer.from_pretrained(modelID)
  model = T5ForConditionalGeneration.from_pretrained(modelID)
  return tokenizer, model


def base_llm(input_text):
  """Generate an answer for one prompt.

  Args:
    input_text: the prompt; None when the SQL argument is NULL.

  Returns:
    The decoded model output with pad/eos tokens removed, stripped and
    upper-cased, or None for a NULL prompt.
  """
  # A NULL prompt has nothing to tokenize; return NULL instead of raising.
  if input_text is None:
    return None

  tokenizer, model = _load_model()
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
  outputs = model.generate(input_ids ,max_length=50)

  response = tokenizer.decode(outputs[0])
  response = response.replace(tokenizer.pad_token, "").replace(tokenizer.eos_token, "")
  return (response.strip().upper())

$$;
'''

cur.execute(sql_statement)

<snowflake.connector.cursor.SnowflakeCursor at 0x7f99b8adbfd0>

In [7]:
# Smoke-test the UDF with one literal prompt; the prompt is echoed in the
# first column so it can be read next to the model's answer.
smoke_test_sql = (
    "select 'translate English to French: What time is it??' as input, "
    "external_access_base_llm_infer( 'translate English to French: What time is it??' ) as llm_result"
)
session.sql(smoke_test_sql).show()

---------------------------------------------------------------------------
|"INPUT"                                         |"LLM_RESULT"            |
---------------------------------------------------------------------------
|translate English to French: What time is it??  |"QUELLE TEMPS EST-CE?"  |
---------------------------------------------------------------------------



In [8]:
# DataFrame over the FLAN_PROMPT table; report how many prompts it holds.
prompt_df = session.table("FLAN_PROMPT")
prompt_count = prompt_df.count()
print(f"Nbr of prompts: {prompt_count:,}")
# To preview the prompts themselves: prompt_df.show(max_width=150)

Nbr of prompts: 14


In [9]:
# Apply the LLM UDF to the PROMPT column and display up to 15 rows, each
# prompt next to the answer generated for it.
prompt_col = F.col("PROMPT")
llm_answer_col = F.call_function("external_access_base_llm_infer", prompt_col)
(
    prompt_df
    .select(prompt_col)
    .select(prompt_col, llm_answer_col)
    .show(15, max_width=150)
)

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"PROMPT"                                                                                                                                                |"EXTERNAL_ACCESS_BASE_LLM_INFER(""PROMPT"")"                                                        |
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|translate English to French: What time is it??                                                                                                          |"QUELLE TEMPS EST-CE?"                                                        

# Done - ZZZZZ