In [1]:
! pip install fosforml

Collecting fosforml
[?25l  Downloading https://files.pythonhosted.org/packages/7b/f9/245945c0ff00abf3f5d47c5ba599e386b85f966e7b1a14e177448b46c26b/fosforml-1.1.2-py3-none-any.whl (43kB)
[K     |████████████████████████████████| 51kB 490kB/s eta 0:00:01
[?25hCollecting scikit-learn==1.3.2
[?25l  Downloading https://files.pythonhosted.org/packages/25/89/dce01a35d354159dcc901e3c7e7eb3fe98de5cb3639c6cd39518d8830caa/scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.9MB)
[K     |████████████████████████████████| 10.9MB 1.8MB/s eta 0:00:01
[?25hCollecting snowflake-ml-python==1.5.0; python_version <= "3.9"
[?25l  Downloading https://files.pythonhosted.org/packages/80/72/c0fa5a9bc811a59a5a1c7113ff89676ed1629d7d6463db8c1a8c97a8b5f6/snowflake_ml_python-1.5.0-py3-none-any.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 42.1MB/s eta 0:00:01
[?25hCollecting cloudpickle==2.2.1
  Downloading https://files.pythonhosted.org/packages/15/80/44286939ca215e

Collecting matplotlib
[?25l  Downloading https://files.pythonhosted.org/packages/8e/67/e75134cb83d2e533e46d72e2033a413772efdc18291beb981f5d574a829f/matplotlib-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3MB)
[K     |████████████████████████████████| 8.3MB 49.7MB/s eta 0:00:01
[?25hCollecting plotly
[?25l  Downloading https://files.pythonhosted.org/packages/0b/f8/b65cdd2be32e442c4efe7b672f73c90b05eab5a7f3f4115efe181d432c60/plotly-5.22.0-py3-none-any.whl (16.4MB)
[K     |████████████████████████████████| 16.4MB 35.1MB/s eta 0:00:01
[?25hCollecting graphviz
[?25l  Downloading https://files.pythonhosted.org/packages/00/be/d59db2d1d52697c6adc9eacaf50e8965b6345cc143f671e1ed068818d5cf/graphviz-0.20.3-py3-none-any.whl (47kB)
[K     |████████████████████████████████| 51kB 9.0MB/s  eta 0:00:01
[?25hCollecting python-dateutil>=2.8.2
[?25l  Downloading https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_d

[?25l  Downloading https://files.pythonhosted.org/packages/32/3f/c02268d0c6fb6b3958bdda673c17b315c821d97df29ae6969f20fb49388a/pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.4MB)
[K     |████████████████████████████████| 4.4MB 54.8MB/s eta 0:00:01
[?25hCollecting contourpy>=1.0.1
[?25l  Downloading https://files.pythonhosted.org/packages/31/a2/2f12e3a6e45935ff694654b710961b03310b0e1ec997ee9f416d3c873f87/contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (304kB)
[K     |████████████████████████████████| 307kB 27.2MB/s eta 0:00:01
[?25hCollecting fonttools>=4.22.0
[?25l  Downloading https://files.pythonhosted.org/packages/7b/30/ad4483dfc5a1999f26b7bc5edc311576f433a3e00dd8aea01f2099c3a29f/fonttools-4.53.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6MB)
[K     |████████████████████████████████| 4.6MB 35.1MB/s eta 0:00:01
[?25hCollecting cycler>=0.10
  Downloading https://files.pythonhosted.org/packages/e7/05/c19819d5

In [None]:
from snowflake.snowpark.session import Session
from snowflake.ml.registry.registry import Registry
import datetime, json
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import MinMaxScaler, LabelEncoder, OneHotEncoder

def get_feature_columns(session, input_table, target_column, max_onehot_cardinality=10):
    """Load ``input_table`` and classify its feature columns for preprocessing.

    A column is treated as categorical when its ``DESCRIBE TABLE`` type contains one of
    VARCHAR / CHAR / STRING / TEXT / BOOL; every other non-target column is numerical.
    All categorical columns are label-encoded; those with fewer than
    ``max_onehot_cardinality`` distinct values are additionally one-hot encoded.

    Args:
        session: Snowpark session used to read the table and run ``DESCRIBE TABLE``.
        input_table: Table name, concatenated verbatim into SQL (trusted input only).
        target_column: Label column; excluded from every returned feature list.
        max_onehot_cardinality: Exclusive upper bound on the number of distinct values
            for a column to be one-hot encoded (default 10).

    Returns:
        tuple: ``(df, numerical_features, le_column_features, oh_column_features)`` where
        ``df`` is ``session.table(input_table)`` and the numerical list keeps the
        table's column order.

    Raises:
        ValueError: if ``target_column`` is not among ``df.columns``.
    """
    df = session.table(input_table)
    features = list(df.columns)
    features.remove(target_column)

    desc_sql_query = 'DESCRIBE TABLE ' + input_table
    df_schema = session.sql(desc_sql_query).collect()
    # Substring match on the reported type, e.g. 'VARCHAR(16777216)' or 'BOOLEAN'.
    categorical_types = ['VARCHAR', 'CHAR', 'STRING', 'TEXT', 'BOOL']
    # Skip the target: it must not be label-encoded as if it were an input feature.
    categorical_features = [
        row['name']
        for row in df_schema
        if row['name'] != target_column
        and any(typ in row['type'] for typ in categorical_types)
    ]
    # Preserve table column order (iterating a set of strings is not stable across runs).
    categorical_set = set(categorical_features)
    numerical_features = [col for col in features if col not in categorical_set]
    print("numerical features: ", numerical_features)

    # Every categorical column is label-encoded; low-cardinality ones are also one-hot encoded.
    # Each distinct().count() runs one query per categorical column.
    le_column_features = categorical_features
    oh_column_features = [
        column
        for column in categorical_features
        if df.select(df[column]).distinct().count() < max_onehot_cardinality
    ]
    print("le_features: ", le_column_features)
    print("oh_features: ", oh_column_features)
    return df, numerical_features, le_column_features, oh_column_features


def create_and_run_preprocessing(df, numerical_features, le_column_features, oh_column_features):
    """Fit and apply the preprocessing pipeline to ``df``, then sanitise column names.

    Pipeline steps, in order:
      * one ``LabelEncoder`` per column in ``le_column_features`` (step ``le_<column>``);
      * one ``OneHotEncoder`` over ``oh_column_features`` (step ``oh_enc``), only if non-empty;
      * one ``MinMaxScaler`` over ``numerical_features`` (step ``scaler``), only if non-empty.
    If there are no steps at all the pipeline is skipped. Afterwards every '.' in a
    column name is replaced with '_'.

    NOTE(review): the pipeline is fitted on the data being scored, so encodings and
    scaling are not the ones learned at training time -- confirm this matches how the
    model was trained.

    Args:
        df: Snowpark DataFrame to transform.
        numerical_features: Columns to min-max scale.
        le_column_features: Columns to label-encode in place.
        oh_column_features: Columns to one-hot encode (unknown categories ignored).

    Returns:
        The transformed DataFrame with '.' replaced by '_' in column names.
    """
    df.show(n=10)

    # Two separate dicts: a chained ``a = b = dict()`` would alias a single dict and
    # register every step twice (duplicate step names / double transforms).
    categorical_pp = {}
    for column in le_column_features:
        categorical_pp['le_' + column] = LabelEncoder(input_cols=column, output_cols=column)
    if oh_column_features:
        categorical_pp['oh_enc'] = OneHotEncoder(input_cols=oh_column_features,
                                                 output_cols=oh_column_features,
                                                 handle_unknown='ignore')
    numerical_pp = {}
    if numerical_features:
        numerical_pp['scaler'] = MinMaxScaler(input_cols=numerical_features,
                                              output_cols=numerical_features)

    steps = list(categorical_pp.items()) + list(numerical_pp.items())
    print("df.columns =", df.columns)
    print("steps =", steps)
    if steps:
        df = Pipeline(steps=steps).fit(df).transform(df)

    # Replace '.' in column names (e.g. produced by encoders) with '_'; only rename
    # columns whose name actually changes.
    for col_name in df.schema.names:
        new_col = col_name.replace('.', '_')
        if new_col != col_name:
            df = df.withColumnRenamed(col_name, new_col)
    df.show(n=10)
    return df


def _get_target_column(session, model_id):
    """Return the training target column recorded in the model's registry metadata.

    Reads ``METADATA:metrics:dataset_details`` from ``INFORMATION_SCHEMA.MODEL_VERSIONS``
    and returns ``target_column`` from its first entry.

    NOTE(review): the query filters only by model name and uses the first row, not the
    requested version -- confirm all versions share the same dataset details.

    Raises:
        ValueError: if no metadata row / dataset details are found.
    """
    import ast

    # Escape single quotes so a model name cannot break out of the string literal.
    escaped_model_id = model_id.replace("'", "''")
    ds_query = ("select METADATA:metrics:dataset_details as dataset "
                "from INFORMATION_SCHEMA.MODEL_VERSIONS where MODEL_NAME='" + escaped_model_id + "';")
    dataset_info = session.sql(ds_query).collect()
    if not dataset_info or dataset_info[0]['DATASET'] is None:
        raise ValueError("no dataset_details metadata found for model " + model_id)
    raw = dataset_info[0]['DATASET']
    if isinstance(raw, str):
        # VARIANT values arrive as JSON text; parse it rather than eval()-ing registry data.
        try:
            details = json.loads(raw)
        except ValueError:
            # Fallback for Python-literal formatted metadata; literal_eval runs no code.
            details = ast.literal_eval(raw)
    else:
        details = raw
    return details[0].get('target_column')


def _write_predictions(session, predictions, table_name):
    """Overwrite ``table_name`` with ``predictions`` (a pandas or Snowpark DataFrame)."""
    if type(predictions).__module__.split('.')[0] == 'pandas':
        session.write_pandas(predictions, table_name, auto_create_table=True, overwrite=True)
    else:
        # mv.run on a Snowpark DataFrame returns a Snowpark DataFrame, which write_pandas rejects.
        predictions.write.save_as_table(table_name, mode="overwrite")


def batch_prediction(session, model_id, version_id, input_table, filter_cond='', output_table=''):
    """Score ``input_table`` with a registered model version and persist the predictions.

    For models whose ``source`` metric is ``EXPERIMENT`` the input is preprocessed with
    ``get_feature_columns`` / ``create_and_run_preprocessing`` (``filter_cond`` is not
    applied on this path). Otherwise ``select * from <input_table> <filter_cond>`` is
    scored directly.

    Args:
        session: Snowpark session.
        model_id: Registry model name.
        version_id: Registry model version name.
        input_table: Table to score (concatenated verbatim into SQL; trusted input only).
        filter_cond: Raw SQL appended after the table name, e.g. ``"where x > 1"``
            (trusted input only).
        output_table: Suffix of the output table; predictions are written (overwritten)
            to ``("PREDICTION_" + output_table).upper()``.

    Returns:
        ``output_table`` exactly as passed in.
    """
    reg = Registry(session=session)
    mv = reg.get_model(model_id).version(version_id)
    source = mv.get_metric('source')

    if source.upper() == 'EXPERIMENT':
        target_column = _get_target_column(session, model_id)
        df, numerical_features, le_column_features, oh_column_features = get_feature_columns(
            session, input_table, target_column)
        data = create_and_run_preprocessing(df, numerical_features, le_column_features, oh_column_features)
    else:
        # Keep a Snowpark DataFrame (not .collect()'s list of Rows), which mv.run accepts.
        data = session.sql("select * from " + input_table + " " + filter_cond)

    remote_prediction = mv.run(data, function_name="predict")
    _write_predictions(session, remote_prediction, ("PREDICTION_" + output_table).upper())
    return output_table


import getpass
import os
import time

# Connection settings come from the environment (password falls back to an interactive
# prompt) so no secret is stored in the notebook.
# NOTE(review): a password was previously hard-coded in this cell and remains in the
# notebook's history -- rotate it.
CONNECTION_PARAMETERS = {
    "account": os.environ.get("SNOWFLAKE_ACCOUNT", "ug94937.us-east4.gcp"),
    "user": os.environ.get("SNOWFLAKE_USER", "ADITYASINGH"),
    "password": os.environ.get("SNOWFLAKE_PASSWORD") or getpass.getpass("Snowflake password: "),
    "role": os.environ.get("SNOWFLAKE_ROLE", "ADITYASINGH"),
    "database": os.environ.get("SNOWFLAKE_DATABASE", "FIRST_DB"),
    "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE", "FOSFOR_INSIGHT_WH"),
    "schema": os.environ.get("SNOWFLAKE_SCHEMA", "PUBLIC"),
}

# Batch-prediction job parameters.
MODEL_ID = 'MODEL_68BCE134_6D7A_4E24_9591_86266438ACF9_FDC_TESTBYPRAKHAR'
MODEL_VERSION = 'V1'
INPUT_TABLE = 'EMPLOYEE'
FILTER_COND = ''
OUTPUT_TABLE = '99999999'  # predictions land in PREDICTION_99999999

start_time = time.time()
session = Session.builder.configs(CONNECTION_PARAMETERS).create()

prediction_table = batch_prediction(session, MODEL_ID, MODEL_VERSION, INPUT_TABLE,
                                    FILTER_COND, OUTPUT_TABLE)
print(f"batch_prediction finished in {time.time() - start_time:.1f}s")
prediction_table