In [17]:
import snowflake.snowpark.functions as F
from snowflake.snowpark.functions import udf
from snowflake.snowpark.types import VariantType
from snowflake.snowpark.session import Session
from snowflake.snowpark.types import StructType, StructField, FloatType
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
import os
import json

In [18]:
# Read connection settings from the environment so no credentials live in the
# notebook. os.getenv returns None for any variable that is not set.
connection_parameters = {
    "account": os.getenv("SNOWFLAKE_ACCOUNT"),
    "user": os.getenv("SNOWFLAKE_USER"),
    "password": os.getenv("SNOWFLAKE_PASSWORD"),
    "schema": os.getenv("SNOWFLAKE_SCHEMA"),
    "database": os.getenv("SNOWFLAKE_DATABASE"),
    "role": os.getenv("SNOWFLAKE_ROLE"),
    "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"),
}

# Fail fast with an actionable message when a credential needed for password
# authentication is unset, instead of attempting to connect with None values.
missing_env_vars = [
    f"SNOWFLAKE_{key.upper()}"
    for key in ("account", "user", "password")
    if not connection_parameters[key]
]
if missing_env_vars:
    raise EnvironmentError(
        f"Missing required environment variable(s): {', '.join(missing_env_vars)}"
    )

session = Session.builder.configs(connection_parameters).create()


In [19]:

# Confirm which database/schema and warehouse the session resolved to.
schema_name = session.get_fully_qualified_current_schema()
print(f"Current Database and schema: {schema_name}")
warehouse_name = session.get_current_warehouse()
print(f"Current Warehouse: {warehouse_name}")

Current Database and schema: "MLOPS"."ADVERTISING"
Current Warehouse: "COMPUTE_WH"


In [20]:
# Snowpark DataFrame over the source table in the session's current schema.
ADVERTISING_TABLE = "ADVERTISING"
ad_df = session.table(ADVERTISING_TABLE)


In [21]:
# Internal stage (no file format specified) that holds the trained model
# (model.joblib) and the uploaded code of the permanent predict_sales UDF
# (registered with stage_location="@ml_models"). IF NOT EXISTS keeps this cell
# idempotent: CREATE OR REPLACE would recreate the stage and discard the files
# in it, breaking the already-registered permanent UDF on a re-run.
session.sql("""
CREATE STAGE IF NOT EXISTS ml_models
""").collect()



[Row(status='Stage area ML_MODELS successfully created.')]

In [22]:
# Scale the warehouse up for the compute-heavy training step.
# NOTE(review): a resize only benefits statements that run while it is in
# effect; make sure the warehouse is still LARGE when `CALL train()` executes.
current_warehouse = session.get_current_warehouse()
if current_warehouse is None:
    raise RuntimeError("No warehouse is set for this session; cannot resize it.")
# get_current_warehouse() returns a quoted identifier (e.g. "COMPUTE_WH"); use it
# as-is so case-sensitive warehouse names resolve correctly.
session.sql(
    f"ALTER WAREHOUSE {current_warehouse} SET WAREHOUSE_SIZE=LARGE"
).collect()


[Row(status='Statement executed successfully.')]

In [23]:

# Python stored procedure that trains the model inside Snowflake. It reads the
# ADVERTISING table, grid-searches a PolynomialFeatures -> StandardScaler ->
# LinearRegression pipeline (10-fold CV over polynomial degree and intercept),
# uploads the fitted GridSearchCV object to @ml_models/model.joblib, and returns
# the best parameters plus train/test R2 scores as a VARIANT.
# RUNTIME_VERSION is written as a quoted string, matching the documented
# CREATE PROCEDURE syntax.
# NOTE(review): scikit-learn/joblib are unpinned here and in the UDF's packages;
# consider pinning both to the same version so the pickled model loads reliably.
create_procedure_sql = """
CREATE OR REPLACE PROCEDURE train()
  RETURNS VARIANT
  LANGUAGE PYTHON
  RUNTIME_VERSION = '3.11'
  PACKAGES = ('snowflake-snowpark-python', 'scikit-learn', 'joblib')
  HANDLER = 'main'
AS $$
import os
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from joblib import dump

def main(session):
  df = session.table('ADVERTISING').to_pandas()
  X = df[['TV', 'RADIO', 'NEWSPAPER']]
  y = df['SALES']
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

  numeric_features = ['TV', 'RADIO', 'NEWSPAPER']
  numeric_transformer = Pipeline(steps=[('poly', PolynomialFeatures()), ('scaler', StandardScaler())])
  preprocessor = ColumnTransformer(transformers=[('num', numeric_transformer, numeric_features)])
  # The final step is a regressor; it keeps the name 'classifier' because that
  # name is part of the parameter keys returned in "Best parameters".
  pipeline = Pipeline(steps=[('preprocessor', preprocessor), ('classifier', LinearRegression(n_jobs=-1))])

  # Define parameter grid for GridSearchCV
  param_grid = {
      'preprocessor__num__poly__degree': [2, 3],
      'classifier__fit_intercept': [True, False]
  }

  model = GridSearchCV(pipeline, param_grid=param_grid, n_jobs=-1, cv=10)
  model.fit(X_train, y_train)

  model_file = os.path.join('/tmp', 'model.joblib')
  dump(model, model_file)
  # Upload uncompressed so the staged file keeps the exact name the
  # predict_sales UDF imports ('@ml_models/model.joblib'); put() gzips by
  # default (auto_compress=True).
  session.file.put(model_file, "@ml_models", auto_compress=False, overwrite=True)

  return {"Best parameters": model.best_params_, "R2 score on Train": model.score(X_train, y_train), "R2 score on Test": model.score(X_test, y_test)}
$$;
"""
session.sql(create_procedure_sql).collect()

[Row(status='Function TRAIN successfully created.')]

In [24]:
# Scale the warehouse back down to XSMALL to limit credit usage.
# NOTE(review): this cell runs before `CALL train()`, so unless the warehouse is
# resized again around the CALL, training executes on XSMALL.
current_warehouse = session.get_current_warehouse()
if current_warehouse is None:
    raise RuntimeError("No warehouse is set for this session; cannot resize it.")
# get_current_warehouse() returns a quoted identifier; use it as-is in SQL.
session.sql(
    f"ALTER WAREHOUSE {current_warehouse} SET WAREHOUSE_SIZE=XSMALL"
).collect()

[Row(status='Statement executed successfully.')]

In [25]:
# Execute the stored procedure to train the model. Training is the
# compute-heavy step, so run the CALL on a LARGE warehouse and always scale
# back to XSMALL afterwards (even if training fails) to limit credit usage.
training_warehouse = session.get_current_warehouse()
if training_warehouse is None:
    raise RuntimeError("No warehouse is set for this session; cannot run train().")
# get_current_warehouse() returns a quoted identifier, usable as-is in SQL.
session.sql(f"ALTER WAREHOUSE {training_warehouse} SET WAREHOUSE_SIZE=LARGE").collect()
try:
    session.sql("CALL train()").show()
finally:
    session.sql(f"ALTER WAREHOUSE {training_warehouse} SET WAREHOUSE_SIZE=XSMALL").collect()


------------------------------------------------
|"TRAIN"                                       |
------------------------------------------------
|{                                             |
|  "Best parameters": {                        |
|    "classifier__fit_intercept": true,        |
|    "preprocessor__num__poly__degree": 2      |
|  },                                          |
|  "R2 score on Test": 9.533174341074796e-01,  |
|  "R2 score on Train": 9.288133512730626e-01  |
|}                                             |
------------------------------------------------



In [26]:
from snowflake.snowpark.functions import udf
import snowflake.snowpark.types as T

# Cache for the deserialized model. A scalar UDF is called once per row, so
# keeping the loaded model in this module-level dict lets later calls in the
# same Python process reuse it instead of re-reading model.joblib every time.
_MODEL_CACHE = {}


def predict_sales(tv: float, radio: float, newspaper: float) -> float:
    """Predict SALES for one row of advertising spend using the staged model.

    Intended to run inside a Snowflake Python UDF: model.joblib is shipped via
    the UDF's ``imports`` and read from the directory given by
    ``sys._xoptions["snowflake_import_directory"]``.

    Args:
        tv: Spend for the TV feature column.
        radio: Spend for the RADIO feature column.
        newspaper: Spend for the NEWSPAPER feature column.

    Returns:
        The model's prediction as a float, or None when any input is NULL
        (None), since the scikit-learn pipeline cannot score missing values.
    """
    import os
    import sys
    import pandas as pd

    if tv is None or radio is None or newspaper is None:
        return None

    model = _MODEL_CACHE.get("model")
    if model is None:
        from joblib import load

        # Directory where Snowflake places files listed in the UDF's imports.
        import_dir = sys._xoptions["snowflake_import_directory"]
        model = load(os.path.join(import_dir, "model.joblib"))
        _MODEL_CACHE["model"] = model

    # Column names must match those the pipeline was fitted on.
    input_data = pd.DataFrame([[tv, radio, newspaper]], columns=["TV", "RADIO", "NEWSPAPER"])
    return float(model.predict(input_data)[0])

# Register predict_sales as a permanent, replaceable UDF. Its code is uploaded
# to @ml_models and the trained model file is attached as an import so the
# handler can read it at runtime.
# NOTE(review): package versions are resolved independently here and in the
# train() procedure; consider pinning scikit-learn/joblib to matching versions.
predict_sales_udf = session.udf.register(
    predict_sales,
    return_type=T.FloatType(),
    input_types=[T.FloatType() for _ in range(3)],
    name="predict_sales",
    is_permanent=True,
    stage_location="@ml_models",
    imports=["@ml_models/model.joblib"],
    packages=["scikit-learn", "pandas", "joblib"],
    replace=True,
)
predict_sales_udf

<snowflake.snowpark.udf.UserDefinedFunction at 0x297cf5e8f10>

In [27]:
from snowflake.snowpark.functions import col
import snowflake.snowpark.functions as F
# Score every ADVERTISING row with the registered predict_sales UDF, keeping
# the three feature columns next to the prediction.
advertising_df = session.table('ADVERTISING')
feature_columns = [col(name) for name in ('TV', 'RADIO', 'NEWSPAPER')]
predicted_sales_df = advertising_df.select(
    *feature_columns,
    F.call_udf('predict_sales', *feature_columns).alias('PREDICTED_SALES'),
)

In [28]:
# Display the first rows of the scored table (show() prints 10 rows by default).
predicted_sales_df.show(n=10)

------------------------------------------------------
|"TV"   |"RADIO"  |"NEWSPAPER"  |"PREDICTED_SALES"   |
------------------------------------------------------
|230.1  |37.8     |69.2         |21.886417875690817  |
|44.5   |39.3     |45.1         |10.372262452131155  |
|17.2   |45.9     |69.3         |9.113870015216659   |
|151.5  |41.3     |58.5         |18.388258741022366  |
|180.8  |10.8     |58.4         |16.125779196227914  |
|8.7    |48.9     |75.0         |8.805982700924098   |
|57.5   |32.8     |23.5         |10.576290497207951  |
|120.2  |19.6     |11.6         |13.689984286660831  |
|8.6    |2.1      |1.0          |5.743650332055633   |
|199.8  |2.6      |21.2         |16.197702901058705  |
------------------------------------------------------



In [29]:
# Best-effort cleanup: suspend the warehouse to stop credit consumption, then
# always close the session. Failures are reported rather than raised so that
# session.close() still runs.
try:
    current_warehouse = session.get_current_warehouse()
    if current_warehouse is None:
        raise RuntimeError("no warehouse is set for this session")
    # get_current_warehouse() returns a quoted identifier (e.g. "COMPUTE_WH");
    # use it as-is so case-sensitive names resolve to the right warehouse.
    session.sql(f"ALTER WAREHOUSE {current_warehouse} SUSPEND").collect()
    # Strip the surrounding quotes only for the human-readable message.
    print(f'Warehouse {current_warehouse[1:-1]} has been suspended successfully.')
except Exception as e:
    print(f'Error suspending the warehouse: {e}')
finally:
    session.close()

Warehouse COMPUTE_WH has been suspended successfully.
