Article for inspiration: https://www.snowflake.com/blog/container-services-llama2-snowpark-ml/

Compute Pool: skhara_compute_gpu7

In [None]:
!pip install transformers

In [None]:
from snowflake.snowpark.session import Session
from snowflake.ml.registry import model_registry
from snowflake.ml.model import deploy_platforms

import json
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Snowflake connection settings are kept outside the notebook in creds.json.
# Use a context manager so the file handle is closed right after reading.
with open('creds.json') as creds_file:
    connection_parameters = json.load(creds_file)
session = Session.builder.configs(connection_parameters).create()

# LLAMA Model Setup

## Load LLAMA Model

In [None]:
# Hugging Face access token (needed for the gated meta-llama model).
# Never hard-code it: notebook sources get shared and committed. Set the HF_AUTH_TOKEN
# environment variable, or type the token when prompted. Any token that was previously
# pasted into this cell has been exposed and should be revoked in the Hugging Face settings.
import os
from getpass import getpass

HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN") or getpass("Hugging Face access token: ")

In [None]:
from transformers import pipeline  # noqa: F401 -- not referenced elsewhere in this notebook
from snowflake.ml.model.models import huggingface_pipeline

# Wrap the Hugging Face text-generation pipeline for Llama-2-7B-chat so it can be logged
# to the Snowpark ML registry. return_full_text / max_new_tokens are transformers
# text-generation options (return only the completion, at most 100 new tokens).
llama_model = huggingface_pipeline.HuggingFacePipelineModel(
    task="text-generation",
    model="meta-llama/Llama-2-7b-chat-hf",
    token=HF_AUTH_TOKEN,
    return_full_text=False,
    max_new_tokens=100,
)

## Register the model

In [None]:
# Location of the Snowpark ML model registry: <database>.<schema>.
registry_name = 'SKHARA'        # database name
schema_name = 'BUILD_REGISTRY'

# Create the registry at that location, then open a handle used to log and deploy models.
# NOTE(review): confirm create_model_registry is safe to re-run against an existing registry.
model_registry.create_model_registry(
    session=session, database_name=registry_name, schema_name=schema_name
)
registry = model_registry.ModelRegistry(
    session=session, database_name=registry_name, schema_name=schema_name
)

In [None]:
# Registry identity of the model; the "Get Predictions" section refers to the same
# name/version using literals, so keep them in sync.
MODEL_NAME = "LLAMA2_MODEL_7b_CHAT"
MODEL_VERSION = "5"

# Log the Hugging Face pipeline wrapper as a new version in the registry.
llama_model_ref = registry.log_model(
    model=llama_model,
    model_name=MODEL_NAME,
    model_version=MODEL_VERSION,
)
llama_model_ref

## Deploy Model

In [None]:
# Deploy the logged model to Snowpark Container Services on a GPU compute pool.
# NOTE(review): the notebook header names compute pool skhara_compute_gpu7, while this
# cell targets SKHARA_COMPUTE_GPU3 -- confirm which pool is intended.
COMPUTE_POOL = "SKHARA_COMPUTE_GPU3"

# Image from an earlier .deploy(). Set this to None when running .deploy() for the first
# time (the option must be omitted then) instead of hand-editing the options below.
PREBUILT_IMAGE = "sfsenorthamerica-fcto-spc.registry.snowflakecomputing.com/skhara/build_registry/snowml_repo/116da812e88f2751324c6a16eb00de3726ed06a3:latest"

deploy_options = {"compute_pool": COMPUTE_POOL, "num_gpus": 1}
if PREBUILT_IMAGE:
    deploy_options["prebuilt_snowflake_image"] = PREBUILT_IMAGE

llama_model_ref.deploy(
    deployment_name="llama_predict",
    platform=deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES,
    options=deploy_options,
    permanent=True,
)

# I/O Setup

We will load a JSON Lines file into a Snowflake table. For prediction there are two options: a Snowpark DataFrame or a local pandas DataFrame.
For simplicity, we use a local pandas DataFrame holding only a small sample of rows. If the dataset is large, Snowpark DataFrames are recommended.

## Load Data

In [None]:
# One record per line of the .jsonl file; convert_dtypes() switches to nullable pandas dtypes.
json_dataset = (
    pd.read_json(
        "frosty_dataset_generator/frosty_transcripts_all.jsonl",
        lines=True,
    )
    .convert_dtypes()
)
json_dataset.head()

In [None]:
# Persist the local dataset as a Snowflake table; existing contents are replaced (overwrite=True).
TABLE_NAME = "AK_BUILD_DATA"
session.write_pandas(
    json_dataset,
    table_name=TABLE_NAME,
    auto_create_table=True,
    overwrite=True,
)

## Input: Prompt Engineering

In [None]:
# Peek at a few rows of the uploaded table (uses TABLE_NAME rather than repeating the literal).
session.table(TABLE_NAME).limit(5).to_pandas()

In [None]:
# Pull a small sample to the client; prompts are built and predictions requested from it.
# NOTE(review): LIMIT without ORDER BY returns an arbitrary subset, so the sample -- and the
# few-shot examples add_prompt takes from its first two rows -- can differ between runs.
N_SAMPLE_ROWS = 20
sdf_input = session.table(TABLE_NAME)
df_local = sdf_input.limit(N_SAMPLE_ROWS).to_pandas()
df_local.head()

In [None]:
def _example_output_json(row):
    """Serialise one labelled example row as the JSON object the prompt asks the model to emit.

    ``toy_list`` is normalised to a real JSON array whether it arrives as a list-like
    (e.g. straight from ``pd.read_json``) or as JSON-encoded text.
    NOTE(review): the text form is assumed to be how the array column comes back from
    Snowflake via ``to_pandas()`` -- confirm for AK_BUILD_DATA.
    Values ``json`` cannot encode natively (e.g. ``pd.NA``) fall back to ``str()``.
    """
    toys = row["toy_list"]
    if isinstance(toys, str):
        try:
            decoded = json.loads(toys)
        except ValueError:
            decoded = None
        if isinstance(decoded, list):
            toys = decoded
    elif hasattr(toys, "tolist"):  # numpy / pandas arrays
        toys = toys.tolist()
    elif isinstance(toys, tuple):
        toys = list(toys)
    return json.dumps(
        {"name": row["name"], "location": row["location"], "toy_list": toys},
        default=str,
    )


def add_prompt(transcript, examples=None):
    """Build the Llama-2 few-shot extraction prompt for one call transcript.

    The prompt asks for a JSON object with ``name``, ``location`` and ``toy_list`` and
    shows the first two rows of ``examples`` as worked examples. Their expected outputs
    are rendered with ``json.dumps`` so the examples are themselves valid JSON; plain
    f-string interpolation would produce unquoted strings / Python list reprs, which
    contradicts the "respond ONLY with valid json" instruction.

    Args:
        transcript: Transcript text inserted after ``Actual Input:``.
        examples: DataFrame with ``transcript``, ``name``, ``location`` and ``toy_list``
            columns and at least two rows. Defaults to the global ``df_local`` sample.

    Returns:
        The prompt string, wrapped in ``[INST] ... [/INST]``.

    Raises:
        IndexError: if ``examples`` has fewer than two rows.

    NOTE(review): when ``examples`` is ``df_local``, its first two rows are also among the
    rows this prompt is applied to, so the model sees their answers in the prompt.
    """
    if examples is None:
        examples = df_local
    first, second = examples.iloc[0], examples.iloc[1]
    first_output = _example_output_json(first)
    second_output = _example_output_json(second)

    prompt = f'''[INST] <PROMPT>
    Your output will be parsed by a computer program as a JSON object. Please respond ONLY with valid json that conforms to this JSON schema:
    {{
      "name": {{
        "type": "string",
        "description": "The name of the person calling"
      }},
      "location": {{
        "type": "string",
        "description": "The name of the location where the person is calling from."
      }},
      "toy_list": {{
        "type": "array",
        "description": "The list of toys requested by the person calling."
      }},
      "required": ["name", "location", "toy_list"]
    }}


    Example 1:
    Input: "{first['transcript']}"
    Output: {first_output}

    Example 2:
    Input: "{second['transcript']}"
    Output: {second_output}
    </PROMPT>

    Actual Input: {transcript}
    [/INST]
    '''
    return prompt

In [None]:
# Attach the full prompt for every sampled transcript, then show one for inspection.
df_local = df_local.assign(inputs=df_local["transcript"].apply(add_prompt))
print(df_local["inputs"].iloc[3])

## Output: Processing
Make sure the parsing code below matches the JSON structure requested in the prompt above.

In [None]:
import json


def format_output(output_string):
    """Extract the JSON object the model produced for one prediction.

    ``output_string`` is one entry of ``res['outputs']``: a JSON-encoded list whose
    first element holds the model reply under ``'generated_text'``.

    Returns:
        The decoded JSON value (normally the ``name`` / ``location`` / ``toy_list`` dict),
        or the string ``'Could not parse output'`` when nothing parseable is found.
    """
    failure = 'Could not parse output'
    try:
        generated_text = json.loads(output_string)[0]['generated_text']
    except (TypeError, ValueError, KeyError, IndexError):
        # Not a JSON list of {'generated_text': ...} records.
        return failure
    if not isinstance(generated_text, str):
        return failure

    # First attempt: everything up to the last closing brace (drops trailing chatter).
    end_pos = generated_text.rfind('}')
    if end_pos != -1:
        try:
            return json.loads(generated_text[:end_pos + 1])
        except ValueError:
            pass

    # Fallback: the model may prepend prose ("Sure! Here is the JSON: ..."), so decode the
    # first complete object starting at the first opening brace and ignore what follows.
    start_pos = generated_text.find('{')
    if start_pos != -1:
        try:
            parsed, _ = json.JSONDecoder().raw_decode(generated_text, start_pos)
            return parsed
        except ValueError:
            pass
    return failure

# Get Predictions

## Get Deployed Model

In [None]:
# Re-open the existing registry so the prediction section can run on its own
# (no create_model_registry call here).
registry_name, schema_name = 'SKHARA', 'BUILD_REGISTRY'
registry = model_registry.ModelRegistry(
    session=session,
    database_name=registry_name,
    schema_name=schema_name,
)

In [None]:
# Everything logged in this registry, shown as a pandas table.
model_list = registry.list_models()
model_list.to_pandas()

In [None]:
# Deployments are looked up by the registry *model* name and version -- 'llama_predict' is
# the deployment_name given to .deploy(), not a model name. The version must be the same
# string used at log time ("5", not the int 5). Literals keep this section runnable on its
# own; they must match MODEL_NAME / MODEL_VERSION from the "Register the model" section.
dep_list = registry.list_deployments(model_name='LLAMA2_MODEL_7b_CHAT', model_version='5')
dep_list.to_pandas()

In [None]:
# Handle to the logged model version; predictions are routed through one of its deployments.
model_name = 'LLAMA2_MODEL_7b_CHAT'
model = model_registry.ModelReference(
    registry=registry,
    model_name=model_name,
    model_version='5',
)

## Predict & See Outputs

In [None]:
# Send each prompt in the 'inputs' column to the 'llama_predict' deployment; the cell below
# reads one res['outputs'] entry per input row.
res = model.predict(deployment_name='llama_predict', data=df_local[['inputs']])

In [None]:
# Show each transcript next to the parsed model response.
for idx, transcript in enumerate(df_local['transcript']):
    print(f'\n\n **** Transcript # {idx} ****')
    print(transcript)
    print('\n')
    print(format_output(res['outputs'].iloc[idx]))