### Regarding Warnings
* Warnings in this notebook are caused by discrepancies in package versions between the local and server environments.
* Readers can ignore them, or update the packages to the required versions with `pip install [package]==[version]`.

### Import Necessary Packages

In [None]:
import configparser
import os

from snowflake.ml.model import deploy_platforms
from snowflake.ml.model.models import huggingface_pipeline
from snowflake.ml.registry import model_registry
from snowflake.snowpark import Session
from transformers import pipeline

In [None]:
# Optional setup: uncomment the lines below to install the packages this
# notebook imports into the kernel's environment.
#%pip install transformers
#%pip install snowflake-ml-python

### Connect To Snowflake

In [2]:
# Load Snowflake connection parameters from the [default] section of a local
# config file and open a Snowpark session with them.
snowflake_credentials_file = '../snowflake_creds.config'
config = configparser.ConfigParser()

# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so check the result to fail with a clear message
# instead of a confusing KeyError on the section lookup below.
if not config.read(snowflake_credentials_file):
    raise FileNotFoundError(
        f"Snowflake credentials file not found or unreadable: {snowflake_credentials_file}"
    )
if 'default' not in config:
    raise KeyError(
        f"Missing [default] section in credentials file: {snowflake_credentials_file}"
    )

connection_parameters = dict(config['default'])

session = Session.builder.configs(connection_parameters).create()

### Load HuggingFace llama2 Model

In [3]:
# Read the Hugging Face access token from the environment instead of embedding
# it in the notebook, so the secret is never saved or committed with the file.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
if not HF_AUTH_TOKEN:
    raise RuntimeError(
        "Set the HF_AUTH_TOKEN environment variable to a Hugging Face access token "
        "before running this cell."
    )

# Open (creating it if needed) the model registry in SNOWPARK.TUTORIAL.
# NOTE: this call emits a deprecation warning pointing to
# snowflake.ml.registry.Registry as the replacement for ModelRegistry.
registry = model_registry.ModelRegistry(
    session=session,
    database_name="SNOWPARK",
    schema_name="TUTORIAL",
    create_if_not_exists=True,
)

# Describe the Llama-2 7B chat text-generation pipeline to register.
# return_full_text=False and max_new_tokens=100 are passed through as
# pipeline options alongside the auth token.
llama_model = huggingface_pipeline.HuggingFacePipelineModel(
    task="text-generation",
    model="meta-llama/Llama-2-7b-chat-hf",
    token=HF_AUTH_TOKEN,
    return_full_text=False,
    max_new_tokens=100,
)

The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.
It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,
except when specifically required. The old model registry will be removed once all its primary functionalities are
fully integrated into the new registry.
        
  registry = model_registry.ModelRegistry(session=session, database_name="SNOWPARK", schema_name="TUTORIAL", create_if_not_exists=True)
create_model_registry() is in private preview since 0.2.0. Do not use it in production. 


### Register llama2 Model

In [6]:
# Optional cleanup (left disabled): uncomment to delete the model registered
# under this name/version — presumably needed before re-running the
# registration cell below with the same name and version; confirm before use.
# MODEL_NAME = "LLAMA2_MODEL_7b_CHAT"
# MODEL_VERSION = "1"
# registry.delete_model( 
#     model_name=MODEL_NAME,
#   model_version=MODEL_VERSION,
# )

In [7]:
# Name and version under which the pipeline is logged in the registry.
MODEL_NAME = "LLAMA2_MODEL_7b_CHAT"
MODEL_VERSION = "1"

# Log the pipeline to the registry. The value returned by log_model() is
# rebound to `llama_model`; the later deploy cell calls .deploy() on it.
llama_model = registry.log_model(
    model=llama_model,
    model_name=MODEL_NAME,
    model_version=MODEL_VERSION,
)



In [8]:
# Show the registry's model listing as a pandas DataFrame (cell output).
registered_models = registry.list_models()
registered_models.to_pandas()

Unnamed: 0,CREATION_CONTEXT,CREATION_ENVIRONMENT_SPEC,CREATION_ROLE,CREATION_TIME,ID,INPUT_SPEC,NAME,OUTPUT_SPEC,RUNTIME_ENVIRONMENT_SPEC,TYPE,URI,VERSION,ARTIFACT_IDS,DESCRIPTION,METRICS,TAGS,REGISTRATION_TIMESTAMP
0,,"{\n ""python"": ""3.8.16""\n}","""ACCOUNTADMIN""",2024-04-12 09:56:39.829000-07:00,980e6a80f8ed11eea87b34f39a51dc3f,,LLAMA2_MODEL_7b_CHAT,,,huggingface_pipeline,sfc://SNOWPARK.TUTORIAL.SNOWML_MODEL_980E6A80F...,1,,,,,2024-04-12 09:56:42.126000-07:00


### Deploy llama2 Model

In [9]:
# Settings handed to the Snowpark Container Services deployment: the compute
# pool name, the GPU count, and the external access integrations.
deployment_options = {
    "compute_pool": "snowpark_cs_compute_pool",
    "num_gpus": 1,
    "external_access_integrations": ["ALLOW_ALL_ACCESS_INTEGRATION"],
}

# Deploy the registered model under the name "llama_predict".
llama_model.deploy(
    deployment_name="llama_predict",
    platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,
    options=deployment_options,
)

### Load News Category JSON as a Snowflake Table

In [28]:
import pandas as pd

# Read the JSON-lines news dataset and let pandas pick nullable dtypes.
news_dataset = pd.read_json(
    "../datasets/News_Category_Dataset_v3.json",
    lines=True,
).convert_dtypes()

# Upload the frame to Snowflake, creating the table if needed and replacing
# any existing contents; write_pandas returns a Snowpark DataFrame for it.
NEWS_DATA_TABLE_NAME = "NEWS_DATASET"
news_dataset_sp_df = session.write_pandas(
    news_dataset,
    NEWS_DATA_TABLE_NAME,
    auto_create_table=True,
    quote_identifiers=False,
    overwrite=True,
)

In [29]:
# Print a preview of the uploaded table (10 rows is show()'s default).
news_dataset_sp_df.show(n=10)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"LINK"                                              |"HEADLINE"                                          |"CATEGORY"      |"SHORT_DESCRIPTION"                                 |"AUTHORS"             |"DATE"               |
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|https://www.huffpost.com/entry/covid-boosters-u...  |Over 4 Million Americans Roll Up Sleeves For Om...  |U.S. NEWS       |Health experts said it is too early to predict ...  |Carla K. Johnson, AP  |1663891200000000000  |
|https://www.huffpost.com/entry/american-airline...  |American Airlines Flyer Charged, Banned For Lif...  |U

### Prompting & Prediction

In [31]:
import snowflake.snowpark.functions as F

# Llama-2 chat prompt: the system instructions sit between <<SYS>> and
# <</SYS>> inside the [INST] ... [/INST] turn. The JSON schema embedded in the
# prompt must itself be valid JSON, and its "required" list must name the
# properties the schema actually defines (category, keywords, importance).
prompt_prefix = """[INST] <<SYS>>
Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema: {"properties": {"category": {"type": "string","description": "The category that the news should belong to."},"keywords": {"type": "array","description": "The keywords that are mentioned in the news.","items": [{"type": "string"}]},"importance": {"type": "number","description": "A integer from 1 to 10 to show if the news is important. The higher the number, the more important the news is."}},"required": ["category","keywords","importance"]}
 
As an example, input "Residents ordered to evacuate amid threat of growing wildfire in Washington state, medical facilities sheltering in place" results in the json: {"category": "Natural Disasters","keywords": ["evacuate", "wildfire", "Washington state", "medical facilities"],"importance": 8}
<</SYS>>
"""
prompt_suffix = "[/INST]"

# Build one prompt per news item: prefix + SHORT_DESCRIPTION + suffix, joined
# with single spaces, and keep only that column.
# NOTE(review): the column is named "input"; confirm this matches the input
# column name the deployed text-generation model expects.
prompt_column = F.concat_ws(
    F.lit(" "),
    F.lit(prompt_prefix),
    F.col('SHORT_DESCRIPTION'),
    F.lit(prompt_suffix),
)
df_inputs = news_dataset_sp_df.with_column('"input"', prompt_column).select('"input"')

# Preview five prompts; limiting in Snowflake avoids downloading the whole
# table into pandas just to display a handful of rows.
df_inputs.limit(5).to_pandas()

Unnamed: 0,input
0,[INST] <>\nYour output will be parsed by a com...
1,[INST] <>\nYour output will be parsed by a com...
2,[INST] <>\nYour output will be parsed by a com...
3,[INST] <>\nYour output will be parsed by a com...
4,[INST] <>\nYour output will be parsed by a com...


In [1]:
# Run the deployed model over the prompt DataFrame.
# Uses the names this notebook actually defines: `llama_model` (the value
# returned by registry.log_model, on which .deploy() was called) and
# `df_inputs` (the prompt column built in the previous section).
# DEPLOYMENT_NAME must match the deployment_name passed to .deploy().
DEPLOYMENT_NAME = "llama_predict"

res = llama_model.predict(
    deployment_name=DEPLOYMENT_NAME,
    data=df_inputs,
)