## 1. Import packages and data

In [None]:
#Import snowflake ML Packages
from snowflake.ml.feature_store import (
    FeatureStore,
    FeatureView,
    Entity,
    CreationMode
)
from snowflake.snowpark import Window, Session
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import *
from snowflake.ml.modeling.preprocessing import OrdinalEncoder, OneHotEncoder, StandardScaler
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.metrics import *
from snowflake.ml.registry import Registry

# Import other python packages
import streamlit as st
import numpy as np # linear algebra
import pandas as pd # data processing
import matplotlib.pyplot as plt
import seaborn as sns
import streamlit as st
from operator import itemgetter
import json

# Snowpark session for the notebook's active Snowflake context; reused by the cells below.
session = get_active_session()

In [None]:
snowflake_environment = session.sql('select current_user(), current_version()').collect()
from snowflake.snowpark.version import VERSION
from snowflake.ml import version

database = session.get_current_database()
schema = session.get_current_schema()
# Current Environment Details
print('User                        : {}'.format(snowflake_environment[0][0]))
print('Role                        : {}'.format(session.get_current_role()))
print('Database                    : {}'.format(database))
print('Schema                      : {}'.format(schema))
print('Warehouse                   : {}'.format(session.get_current_warehouse()))
print('Snowflake version           : {}'.format(snowflake_environment[0][1]))
# Snowpark's VERSION is a (major, minor, patch) tuple.
print('Snowpark for Python version : {}.{}.{}'.format(VERSION[0],VERSION[1],VERSION[2]))
# snowflake.ml.version.VERSION is already a dotted "major.minor.patch" string. Picking
# characters 0/2/4 out of it (as before) garbles any multi-digit component, e.g. "1.10.0".
print('Snowflake ML version        : {}'.format(version.VERSION))

In [None]:
# Original and newly arrived claims, both stored in the lab's DATA schema.
LAB_SCHEMA = "CONTAINER_RUNTIME_LAB.DATA"
claim_data = session.read.table(f"{LAB_SCHEMA}.CLAIM_DATA")
claim_data_new = session.read.table(f"{LAB_SCHEMA}.CLAIM_DATA_NEW")

In [None]:
# Tag each batch so the two cohorts can be separated again after scoring.
claim_data = claim_data.with_column('SOURCE', lit('ORIGINAL'))
claim_data_new = claim_data_new.with_column('SOURCE', lit('NEW'))

# Column name -> Snowpark data type for each DataFrame.
types_claim_data = {field.name: field.datatype for field in claim_data.schema.fields}
types_claim_data_new = {field.name: field.datatype for field in claim_data_new.schema.fields}
columns_claim_data = list(types_claim_data)
columns_claim_data_new = list(types_claim_data_new)

# Union of all column names, sorted so both frames share one positional column order.
all_columns = set(columns_claim_data).union(set(columns_claim_data_new))
all_columns_sorted = sorted(all_columns)


def _align_columns(df, present_types):
    """Project ``df`` onto ``all_columns_sorted``, adding typed NULL columns for any it lacks.

    Selecting a column that only exists in the other frame would fail, so missing columns
    are synthesised as NULL cast to the type they have in the frame that does contain them.
    """
    projection = []
    for name in all_columns_sorted:
        if name in present_types:
            projection.append(col(name))
        else:
            missing_type = types_claim_data.get(name, types_claim_data_new.get(name))
            projection.append(lit(None).cast(missing_type).alias(name))
    return df.select(projection)


claim_data = _align_columns(claim_data, types_claim_data)
claim_data_new = _align_columns(claim_data_new, types_claim_data_new)

# union_all matches columns by position, which is now identical in both frames.
claim_data_combined = claim_data.union_all(claim_data_new)

In [None]:
# Feature store living in the DATA schema of the current database. It is created if it does
# not exist yet, with CONTAINER_RUNTIME_WH as its default warehouse.
fs = FeatureStore(
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
    default_warehouse='CONTAINER_RUNTIME_WH',
    name='DATA',
    database=database,
    session=session,
)

In [None]:
# Fetch version V1 of the customer_data feature view from the feature store.
customer_fv = fs.get_feature_view(name='customer_data', version='V1')

In [None]:
# Treat the '?' placeholder as a missing value (NULL) across the combined claims.
claim_data_combined = claim_data_combined.replace('?', None)
counter = claim_data_combined.count()
print(f"New claim_data_combined count:  {counter}")
# Preview rows whose police-report flag is missing.
claim_data_combined.filter(col("POLICE_REPORT_AVAILABLE").is_null()).show(10)
# Calculate the mode of the 'POLICE_REPORT_AVAILABLE' column
mode_value = claim_data_combined.select(mode(col("POLICE_REPORT_AVAILABLE"))).collect()[0][0]
if mode_value is None:
    # An all-NULL column has no mode, so there is nothing to impute with.
    print("POLICE_REPORT_AVAILABLE has no non-NULL values; NULLs left in place")
else:
    print(f"Fill NULL value in POLICE_REPORT_AVAILABLE to the mode: {mode_value}")
    # Fill NULL values with the mode
    claim_data_combined = claim_data_combined.with_column(
        "POLICE_REPORT_AVAILABLE",
        when(col("POLICE_REPORT_AVAILABLE").is_null(), mode_value)
        .otherwise(col("POLICE_REPORT_AVAILABLE")),
    )

In [None]:
# Join the customer features onto the cleaned claims spine and build the training dataset.
# A fixed dataset version (e.g. version='v21') can be passed here to pin it.
data_combined = fs.generate_dataset(
    spine_df=claim_data_combined,
    features=[customer_fv],
    spine_label_cols=["FRAUD_REPORTED"],
    name="fraud_classification_demo",
)

In [None]:
# Derive POLICY_DURATION: months between the policy start date and the incident date.
from snowflake.snowpark.functions import col

policy_duration = floor(datediff("month", col("POLICY_START_DATE"), col("INCIDENT_DATE")))
data_combined_df = (
    data_combined.read.to_snowpark_dataframe()
    .with_column("POLICY_DURATION", policy_duration)
    # Drop the raw date fields and the age field, as was done in part 1.
    .drop("age", "INCIDENT_DATE", "POLICY_START_DATE")
)

In [None]:
# Quick sanity check: total row count, then a 50-row preview rendered with Streamlit.
row_count = data_combined_df.count()
print(f"row count: {row_count}")
st.dataframe(data_combined_df.limit(50))

## 2. Pre-process data and import model

In [None]:
# Pull the previously trained model back from the Model Registry in CONTAINER_RUNTIME_LAB.DATA.
reg = Registry(
    session=session,
    database_name="CONTAINER_RUNTIME_LAB",
    schema_name="DATA",
)
# .last() takes the latest version; use .version("<version_name>") to pin a specific one.
xgb_gs_fraud_model = reg.get_model("xgb_gs_fraud_model").last()

In [None]:
# Score every row (original + new cohorts) with the registered model's PREDICT function.
new_predictions = xgb_gs_fraud_model.run(data_combined_df, function_name='PREDICT')

In [None]:
# Separate the scored rows back into the two cohorts using the SOURCE tag added earlier.
original_data, new_data = (
    new_predictions.filter(col('SOURCE') == cohort) for cohort in ("ORIGINAL", "NEW")
)

In [None]:
# Ground-truth label and model-output columns used by the snowflake.ml metrics below.
label_column = ['FRAUD_REPORTED']
output_column = ['PREDICTED_FRAUD']

ACCURACY_ORIGINAL = accuracy_score(df=original_data, y_true_col_names=label_column, y_pred_col_names=output_column)
# NOTE(review): PREDICTED_FRAUD is also used as y_pred for accuracy, so it appears to hold hard
# class labels; an AUC from labels rather than probabilities is coarse. TODO confirm whether a
# probability column should be scored instead.
AUC_ORIGINAL = roc_auc_score(df=original_data, y_true_col_names=label_column, y_score_col_names=output_column)

print(f'Accuracy (Original Data): {ACCURACY_ORIGINAL}')
print(f'AUC (Original Data): {AUC_ORIGINAL}')

In [None]:
# Record both evaluation metrics on the model version. The keys must be distinct: a repeated
# 'Acc' key previously let the AUC silently overwrite the accuracy.
xgb_gs_fraud_model.set_metric("Evaluation_Info", {'Acc': ACCURACY_ORIGINAL, 'AUC': AUC_ORIGINAL})

In [None]:
# Inspect every registered version of the fraud model. Stored metrics can be read back with
# xgb_gs_fraud_model.show_metrics()["Evaluation_Info"].
fraud_model_versions = reg.get_model('xgb_gs_fraud_model').show_versions()
fraud_model_versions

In [None]:
# Same metrics as above, but for the NEW cohort, to see how the deployed model copes with it.
ACCURACY = accuracy_score(df=new_data, y_true_col_names=label_column, y_pred_col_names=output_column)
AUC = roc_auc_score(df=new_data, y_true_col_names=label_column, y_score_col_names=output_column)
print(f'Accuracy (New Data): {ACCURACY}')
print(f'AUC (New Data): {AUC}')

## 4. Examine PSI and KDE across predictions and features

In [None]:
def calculate_psi(expected_array, actual_array, buckets=3, epsilon=1e-6):
    """Population Stability Index of ``actual_array`` relative to ``expected_array``.

    Bins are equal-width over the range of the expected (reference) values. The outermost
    edges are opened to -inf/+inf so that actual values outside the reference range still
    land in the first/last bin instead of being silently dropped.

    Args:
        expected_array: Reference values (array-like or pandas Series; numeric or boolean).
        actual_array: Values to compare against the reference.
        buckets: Number of bins (>= 1).
        epsilon: Substituted for empty-bin proportions to avoid log(0) / division by zero.

    Returns:
        float: sum over bins of (actual% - expected%) * ln(actual% / expected%).

    Raises:
        ValueError: If ``buckets`` < 1 or either input has no non-NaN values.
    """
    def _as_float_array(values):
        # pandas .astype(float) turns nullable NA into NaN; plain sequences go straight to numpy.
        if hasattr(values, "astype"):
            values = values.astype(float)
        arr = np.asarray(values, dtype=float)
        # NaNs carry no distributional information and would poison min/max.
        return arr[~np.isnan(arr)]

    if buckets < 1:
        raise ValueError("buckets must be >= 1")
    expected = _as_float_array(expected_array)
    actual = _as_float_array(actual_array)
    if expected.size == 0 or actual.size == 0:
        raise ValueError("calculate_psi needs at least one non-NaN value in each input")

    # Equal-width interior edges from the reference range; open-ended outer edges.
    breakpoints = np.linspace(expected.min(), expected.max(), buckets + 1)
    breakpoints[0], breakpoints[-1] = -np.inf, np.inf

    expected_counts = np.histogram(expected, bins=breakpoints)[0]
    actual_counts = np.histogram(actual, bins=breakpoints)[0]

    expected_percents = expected_counts / expected_counts.sum()
    actual_percents = actual_counts / actual_counts.sum()

    # Smoothing: replace empty-bin proportions to avoid division by zero or log of zero.
    expected_percents = np.where(expected_percents == 0, epsilon, expected_percents)
    actual_percents = np.where(actual_percents == 0, epsilon, actual_percents)

    psi_values = (actual_percents - expected_percents) * np.log(actual_percents / expected_percents)
    return float(np.sum(psi_values))

In [None]:
# Persist both cohorts so the monitoring stored procedure can read them back later
# (REFERENCE_DATA = original cohort, CURRENT_DATA = new cohort). The reference table can be
# reloaded with session.read.table("CONTAINER_RUNTIME_LAB.DATA.REFERENCE_DATA").to_pandas().
for cohort_df, target_table in (
    (original_data, 'CONTAINER_RUNTIME_LAB.DATA.REFERENCE_DATA'),
    (new_data, 'CONTAINER_RUNTIME_LAB.DATA.CURRENT_DATA'),
):
    cohort_df.write.save_as_table(target_table, mode='overwrite')

# Local pandas copies for plotting. Despite its name, new_predictions_pd holds only the NEW cohort.
original_data_pd = original_data.to_pandas()
new_predictions_pd = new_data.to_pandas()

In [None]:
# Overlay the FRAUD_REPORTED distribution of both cohorts, then quantify the shift with PSI.
plt.figure(figsize=(4, 3))
sns.set_context("notebook")
for series, cohort_label, colour in (
    (original_data_pd['FRAUD_REPORTED'], 'training_data', 'blue'),
    (new_predictions_pd['FRAUD_REPORTED'], 'new_data', 'red'),
):
    sns.kdeplot(data=series, label=cohort_label, fill=True, color=colour, common_norm=False)
plt.title('Kernel Density Estimate of "FRAUD_REPORTED" for Training Data and New Data')
plt.xlabel('Likelihood of Fraud')
plt.ylabel('Density')
plt.legend()
plt.show()

psi = calculate_psi(original_data_pd['FRAUD_REPORTED'], new_predictions_pd['FRAUD_REPORTED'])
print(f"Fraud PSI: {psi}")

# Rule of thumb for reading the Population Stability Index:
#   PSI < 0.1         -> no significant change
#   0.1 <= PSI < 0.2  -> moderate change
#   PSI >= 0.2        -> significant change (drift has likely occurred)

In [None]:
# FRAUD_REPORTED is binary, so the share of each class within each cohort (normalised per
# SOURCE) is a direct way to compare the two datasets.
combined_data = pd.concat([original_data_pd, new_predictions_pd])

proportions = (
    combined_data.groupby(['SOURCE', 'FRAUD_REPORTED'])
    .size()
    .reset_index(name='Count')
)
cohort_totals = proportions.groupby('SOURCE')['Count'].transform('sum')
proportions['Proportion'] = proportions['Count'] / cohort_totals

sns.barplot(data=proportions, x='FRAUD_REPORTED', y='Proportion', hue='SOURCE')
plt.xlabel('Likelihood of Fraud')
plt.ylabel('Proportion')
plt.title('Proportion of "FRAUD_REPORTED" for Training Data and New Data')
plt.legend()
plt.show()

psi = calculate_psi(original_data_pd['FRAUD_REPORTED'], new_predictions_pd['FRAUD_REPORTED'])
print(f"Fraud PSI: {psi}")

# Rule of thumb for reading the Population Stability Index:
#   PSI < 0.1         -> no significant change
#   0.1 <= PSI < 0.2  -> moderate change
#   PSI >= 0.2        -> significant change (drift has likely occurred)

In [None]:
# PSI for every numeric (including boolean) column, excluding the POLICY_NUMBER identifier.
# Both frames are filters of the same scored DataFrame, so they share the same columns.
columns = [name for name in original_data_pd.columns if name not in ['POLICY_NUMBER']]

# (column, PSI) pairs
psi_values = []

for column in columns:
    # pandas' dtype check accepts bool and nullable extension dtypes (e.g. Int64). np.issubdtype
    # rejected bools (despite the message below) and raises TypeError on extension dtypes.
    if pd.api.types.is_numeric_dtype(original_data_pd[column]):
        psi = calculate_psi(original_data_pd[column], new_predictions_pd[column])
        psi_values.append((column, psi))
    else:
        st.write(f"Skipping {column} - not numeric or boolean")

# Sort PSI values from highest to lowest
sorted_psi = sorted(psi_values, key=itemgetter(1), reverse=True)

psi_df = pd.DataFrame(sorted_psi, columns=['Column', 'PSI']).set_index('Column')

# Display the top 5 columns with the highest PSI values using Streamlit
st.title("Top 5 Features with Highest PSI Values")
st.table(psi_df.head(5).style.format({'PSI': '{:.3f}'}))

In [None]:
# Distribution of the incident hour in both cohorts, plus its PSI.
plt.figure(figsize=(4, 3))
sns.set_context("notebook")
sns.kdeplot(data=original_data_pd['INCIDENT_HOUR_OF_THE_DAY'], label='training_data', fill=True, color='blue')
sns.kdeplot(data=new_predictions_pd['INCIDENT_HOUR_OF_THE_DAY'], label='new_data', fill=True, color='red')
plt.title('Kernel Density Estimate of "INCIDENT_HOUR_OF_THE_DAY" for Training Data and New Data')
# The x axis is the incident hour, not a fraud likelihood.
plt.xlabel('Incident Hour of the Day')
plt.ylabel('Density')
plt.legend()
plt.show()
psi = calculate_psi(original_data_pd['INCIDENT_HOUR_OF_THE_DAY'], new_predictions_pd['INCIDENT_HOUR_OF_THE_DAY'])
print(f"INCIDENT_HOUR_OF_THE_DAY PSI: {psi}")

## 5. Use Evidently to generate reports and automatically push to Streamlit + Email

In [None]:
def generate_sql_email_message(data: dict, file: str,
                               integration_name: str = 'my_email_int',
                               recipients=('harley.chen@snowflake.com',)) -> str:
    """Build a SYSTEM$SEND_SNOWFLAKE_NOTIFICATION call summarising an Evidently test run.

    Args:
        data: Evidently test-suite dict. It must contain a non-empty 'summary' and may contain
            a 'tests' list whose entries carry 'status' and, optionally, 'parameters'.
        file: Report download URL, linked from the email body.
        integration_name: Name of the Snowflake email notification integration to use.
        recipients: Email address(es) to notify; a single string is also accepted.

    Returns:
        The SQL CALL statement, or an error string if ``data`` has no usable 'summary'.
    """
    from datetime import datetime
    import pandas as pd
    import pytz

    def _sql_string(text: str) -> str:
        # Snowflake single-quoted literal: backslash is an escape character and quotes are doubled.
        return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'"

    tz = pytz.timezone('Europe/London')
    current_timestamp = datetime.now(tz).strftime("%d-%m-%Y %H:%M:%S")

    # Generate summary
    summary = data.get('summary', {})
    if not summary:
        return "Error: 'summary' key not found in data"

    # Generate table with column test status
    df = pd.DataFrame(data.get('tests', []))
    if not df.empty:
        # Not every test has a 'parameters' dict (it may be absent or None), so fall back to 'N/A'.
        if 'parameters' in df.columns:
            df['column_name'] = df['parameters'].apply(
                lambda params: params.get('column_name', 'N/A') if isinstance(params, dict) else 'N/A'
            )
        else:
            df['column_name'] = 'N/A'
        if 'status' not in df.columns:
            df['status'] = 'N/A'
        df = df[['column_name', 'status']]
        table_html = df.to_html(index=False).replace("'", '"')
    else:
        table_html = "<p>No test data available</p>"

    email_content = f"""
    Date: {current_timestamp} <br>
    Successful tests: {summary.get('success_tests', 0)} <br>
    Failed tests: {summary.get('failed_tests', 0)} <br>
    Report: <a href="{file}">Download Report</a>
    
    {table_html}
    """
    # Remove non-ASCII characters
    email_content_ascii = email_content.encode('ascii', 'ignore').decode('ascii')

    if isinstance(recipients, str):
        recipients = [recipients]
    recipient_list = ', '.join(_sql_string(address) for address in recipients)

    sql = f"""
    CALL SYSTEM$SEND_SNOWFLAKE_NOTIFICATION(
        SNOWFLAKE.NOTIFICATION.TEXT_HTML({_sql_string(email_content_ascii)}),
        SNOWFLAKE.NOTIFICATION.EMAIL_INTEGRATION_CONFIG(
            {_sql_string(integration_name)},
            'Drift Detection Report {current_timestamp}',
            ARRAY_CONSTRUCT({recipient_list})
        )
    )
    """
    return sql

# Registers monitor_model as the permanent stored procedure "evidently_monitor" (code staged on
# @MONITORING, replacing any existing definition). The packages are resolved inside Snowflake.
@sproc(session=session, name='evidently_monitor', stage_location='@MONITORING',  
       packages=['snowflake-snowpark-python', 'pandas', 'evidently', 'snowflake-ml-python', 'tabulate', 'pytz'],
       is_permanent=True, 
       replace=True)
def monitor_model(session: Session) -> dict:
    """Run an Evidently PSI drift test suite on REFERENCE_DATA vs CURRENT_DATA and report it.

    Steps:
      1. Load a fixed list of feature columns from both tables into pandas.
      2. Run TestSuite(DataDriftTestPreset(stattest="psi", stattest_threshold=0.3)).
      3. Save the HTML report to /tmp, PUT it to @MONITORING/report/YYYY/MM/DD/, and fetch a
         presigned download URL for it.
      4. If the summary reports any failed test, build and execute a notification email via
         generate_sql_email_message.

    Returns:
        The report's ``as_dict()`` output merged with 'reference_head' / 'current_head'
        (first 5 rows of each input), or ``{'Error': <message + traceback>}`` as soon as
        any of steps 2-4 fails.
    """
    from evidently.test_preset import DataDriftTestPreset
    from evidently.test_suite import TestSuite
    from datetime import datetime
    import pandas as pd
    import pytz
    import traceback
    import json  # NOTE(review): imported but not used in this body.
    
    output = {}
    tz = pytz.timezone('Europe/London')
    # NOTE(review): '"""X"""' is the Snowflake quoted identifier whose name contains literal
    # double quotes; presumably this matches how the one-hot encoder named its output
    # columns. TODO confirm against the REFERENCE_DATA schema.
    feature_columns = ['INCIDENT_HOUR_OF_THE_DAY', 
                       '"""INSURED_OCCUPATION_encoded_handlers-cleaners"""',
                       '"""INSURED_OCCUPATION_encoded_prof-specialty"""',
                       'POLICY_DEDUCTABLE_SCALED'] # Just picking the top columns from the Calc_PSI cell.
                                                   # You can include all fields.
    # Load reference data
    reference = session.table("CONTAINER_RUNTIME_LAB.DATA.REFERENCE_DATA").select(feature_columns).to_pandas()
    output['reference_head'] = reference.head(5).to_dict()
    # Load current data
    current = session.table('CONTAINER_RUNTIME_LAB.DATA.CURRENT_DATA').select(feature_columns).to_pandas()
    output['current_head'] = current.head(5).to_dict()
    
    try:
        # Generate the report: per-column PSI drift tests, failing above a PSI of 0.3.
        report = TestSuite(tests=[DataDriftTestPreset(stattest="psi", stattest_threshold=0.3)])
        report.run(reference_data=reference, current_data=current)
    except Exception as e:
        error_msg = f"Error generating report: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {'Error': error_msg}

    try:
        # Upload report to stage under a date-partitioned path; the file name is the full timestamp.
        timestamp = datetime.now(tz)
        timestamp1 = timestamp.strftime('%Y_%m_%d_%H_%M_%S')
        timestamp2 = timestamp.strftime('%Y/%m/%d/')
        filename = f"/tmp/{timestamp1}.html"
        stage_filename = f"@MONITORING/report/{timestamp2}"
        # NOTE(review): the local /tmp file is not removed after the upload.
        report.save_html(filename)
        session.file.put(filename, stage_filename, auto_compress=False, overwrite=True)
        # Presigned URL for the uploaded file, embedded in the notification email below.
        download_url_query = f"SELECT GET_PRESIGNED_URL('@MONITORING', 'report/{timestamp2}{timestamp1}.html') AS DOWNLOAD_LINK"
        download_url = session.sql(download_url_query).collect()[0]['DOWNLOAD_LINK']
    except Exception as e:
        error_msg = f"Error uploading report: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {'Error': error_msg}
    
    try:
        # Dict form of the results; its 'summary' entry drives the email decision below.
        test_summary = report.as_dict()
        print(f"test_summary keys: {test_summary.keys()}")
    except Exception as e:
        error_msg = f"Error converting report to dict: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return {'Error': error_msg}

    # Send an email if any tests failed
    if test_summary.get('summary', {}).get('failed_tests', 0) > 0:
        try:
            sql_query = generate_sql_email_message(test_summary, download_url)
            print(f"Generated SQL query: {sql_query}")
            session.sql(sql_query).collect()
        except Exception as e:
            error_msg = f"Error sending email: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)
            return {'Error': error_msg}

    # Merge the test results with the data previews; on a key clash the preview entries win.
    result = {**test_summary, **output}
    print(f"Final result keys: {result.keys()}")
    return result


In [None]:
import json

# The registered procedure returns its result as JSON text (hence json.loads); accept a dict too.
drift_detection_results_str = monitor_model(session)
if isinstance(drift_detection_results_str, str):
    drift_detection_results = json.loads(drift_detection_results_str)
else:
    drift_detection_results = drift_detection_results_str

# monitor_model reports failures as {'Error': ...}; surface them instead of a KeyError on 'summary'.
if 'Error' in drift_detection_results:
    raise RuntimeError(f"evidently_monitor failed:\n{drift_detection_results['Error']}")

drift_detection_results['summary']

In [None]:
def plot_evidently_results(drift_detection_results):
    """Draw a three-panel summary of an Evidently drift test run.

    Panels: (1) number of tests per status, (2) per-feature drift score, and
    (3) share of features flagged as drifted.

    Args:
        drift_detection_results: dict with a 'summary' entry (whose 'by_status' maps status to
            count) and a 'tests' list. Tests whose 'parameters' include a 'features' mapping
            contribute one row per feature (detected / score / stattest / threshold).
    """
    summary_data = drift_detection_results['summary']
    by_status = summary_data.get('by_status', {})

    # Flatten per-feature drift details from every test that reports them.
    feature_drift = []
    for test in drift_detection_results.get('tests', []):
        parameters = test.get('parameters') or {}
        for feature, params in (parameters.get('features') or {}).items():
            feature_drift.append({
                'feature': feature,
                'detected': params['detected'],
                'score': params['score'],
                'stattest': params['stattest'],
                'threshold': params['threshold'],
            })
    feature_drift_df = pd.DataFrame(
        feature_drift, columns=['feature', 'detected', 'score', 'stattest', 'threshold']
    )

    fig, ax = plt.subplots(1, 3, figsize=(20, 8))

    # Summary plot (list() because dict views are not reliably accepted as bar positions).
    ax[0].bar(list(by_status.keys()), list(by_status.values()), color=['#FF9999', '#66B2FF'])
    ax[0].set_title('Test Summary', fontsize=16)
    ax[0].set_ylabel('Count', fontsize=12)
    ax[0].set_xlabel('Status', fontsize=12)
    ax[0].tick_params(axis='both', which='major', labelsize=10)

    if feature_drift_df.empty:
        # Nothing per-feature to draw; say so rather than failing on a missing column.
        for axis in (ax[1], ax[2]):
            axis.text(0.5, 0.5, 'No per-feature drift data', ha='center', va='center',
                      transform=axis.transAxes)
    else:
        # Drift scores bar plot
        feature_drift_df.plot(kind='bar', x='feature', y='score', ax=ax[1], color='#66B2FF')
        ax[1].set_ylabel('Drift Score', fontsize=12)
        ax[1].set_xlabel('Feature', fontsize=12)
        ax[1].tick_params(axis='both', which='major', labelsize=10)
        plt.setp(ax[1].get_xticklabels(), rotation=45, ha='right')

        # Pie chart for drifted features. value_counts() orders by frequency and may hold a
        # single class, so labels/colours are derived from its index, not a fixed list.
        drifted_counts = feature_drift_df['detected'].value_counts()
        pie_labels = ['Drift' if flag else 'No Drift' for flag in drifted_counts.index]
        pie_colors = ['#FF9999' if flag else '#66B2FF' for flag in drifted_counts.index]
        ax[2].pie(drifted_counts, labels=pie_labels, autopct='%1.1f%%', colors=pie_colors,
                  textprops={'fontsize': 10})

    ax[1].set_title('Feature Drift Scores', fontsize=16)
    ax[2].set_title('Proportion of Drifted Features', fontsize=16)

    plt.tight_layout(pad=0.5)
    plt.show()

In [None]:
# Visualise the outcome returned by the drift-monitoring procedure.
plot_evidently_results(drift_detection_results=drift_detection_results)

In [None]:
from snowflake.snowpark.functions import col
from snowflake.snowpark.types import LongType

# The new batch alone is not large or credible enough to train on, so old and new data are
# pooled and re-split 80/20 with a fixed seed. With a sufficiently large new batch you could
# train on the new data alone instead.
train_data_all, test_data_all = (
    part.with_column("FRAUD_REPORTED", col("FRAUD_REPORTED").astype(LongType()))
    for part in data_combined_df.drop('source').random_split(weights=[0.8, 0.2], seed=42)
)

# Label and prediction column names for the retrained model.
label_column_all = ['FRAUD_REPORTED']
output_column_all = ['PREDICTED_FRAUD']

In [None]:
# Define the categories with their specific order
categories = {
    "INSURED_EDUCATION_LEVEL": np.array(["High School", "Associate", "College", "Masters", "JD", "MD", "PhD"]),
    "INCIDENT_SEVERITY": np.array(["Trivial Damage", "Minor Damage", "Major Damage", "Total Loss"])
}
# Ordinal-encode the two naturally ordered categoricals; values outside the lists map to -1.
OrdinalEncoding = OrdinalEncoder(
    input_cols=["INSURED_EDUCATION_LEVEL", "INCIDENT_SEVERITY"],
    output_cols=["INSURED_EDUCATION_LEVEL_OE", "INCIDENT_SEVERITY_OE"],
    categories=categories,
    handle_unknown="use_encoded_value",
    unknown_value=-1,
    drop_input_cols=True
)

# Nominal categoricals to one-hot encode
columns_to_encode = [
    "INSURED_SEX",
    "INSURED_OCCUPATION",
    "INCIDENT_TYPE",
    "AUTHORITIES_CONTACTED",
    "POLICE_REPORT_AVAILABLE"
]
OneHotEncoding = OneHotEncoder(
    input_cols=columns_to_encode,
    output_cols=[f"{column_name}_encoded" for column_name in columns_to_encode],
    drop_input_cols=True,  # replace the original columns with their encoded versions
    handle_unknown='ignore'  # Ignore any unknown categories during transform
)

# Numeric columns to standardise (zero mean, unit variance)
columns_to_scale = [
    'POLICY_LENGTH_MONTH',
    'POLICY_DEDUCTABLE',
    'POLICY_ANNUAL_PREMIUM',
    'CLAIM_AMOUNT',
    'POLICY_DURATION'
]
StandardScaling = StandardScaler(
    input_cols=columns_to_scale,
    output_cols=[f"{column_name}_SCALED" for column_name in columns_to_scale],
    with_mean=True,
    with_std=True,
    drop_input_cols=True  # replace the original columns with their scaled versions
)

# Label and prediction column names
label_column = ['FRAUD_REPORTED']
output_column = ['PREDICTED_FRAUD']

# A plain XGBClassifier overfits the training data and performs poorly on the test split, so
# the hyper-parameters are tuned with a grid search instead. The grid has 9 x 3 x 8 x 8 = 1,728
# combinations (each cross-validated), which is why the warehouse is scaled up before fitting.
# input_cols is omitted, so presumably all non-label columns are used as features; set it
# explicitly if only specific columns should be used.
xgb_grid_search = GridSearchCV(
    estimator=XGBClassifier(random_state=42),  # fixed seed: subsample < 1 makes training stochastic
    param_grid={
        "n_estimators":[10, 20, 30, 50, 100, 150, 200, 250, 300],
        "subsample": [0.9, 0.5, 0.2],
        "max_depth": range(2,10,1),
        "learning_rate":[0.1, 0.06, 0.05, 0.03, 0.01, 0.005, 0.002, 0.001],
    },
    n_jobs = -1,
    label_cols=label_column,
    output_cols=output_column,
)

# Preprocessing and the grid-searched classifier run as one pipeline, so the fitted object
# can be registered and applied to raw rows.
model_pipeline = Pipeline(
    steps=[
        ("Ordinal_encoding",OrdinalEncoding),
        ("OneHotEncoding",OneHotEncoding),
        ("standardscaler",StandardScaling),
        ("CV_XGBClassifier", xgb_grid_search)
    ]
)

In [None]:
import time

# Scale the current warehouse up to LARGE for the grid search, showing its state before/after.
wh = str(session.get_current_warehouse()).strip('"')
print(f"Current warehouse: {wh}")
show_wh_sql = f"SHOW WAREHOUSES LIKE '{wh}';"
print(session.sql(show_wh_sql).collect())

session.sql(f"alter warehouse {session.get_current_warehouse()} set WAREHOUSE_SIZE = LARGE").collect()
time.sleep(5)  # give Snowflake a few seconds to apply the new size

print(session.sql(show_wh_sql).collect())

In [None]:
# Fit preprocessing + grid search on the pooled (old + new) training split.
# Typically trains for about 3mins in a large wh
xgb_gs_fitted = model_pipeline.fit(train_data_all)

In [None]:
import time

# Training is done: shrink the current warehouse back to XSMALL, showing its state before/after.
wh = str(session.get_current_warehouse()).strip('"')
print(f"Current warehouse: {wh}")
show_wh_sql = f"SHOW WAREHOUSES LIKE '{wh}';"
print(session.sql(show_wh_sql).collect())

session.sql(f"alter warehouse {session.get_current_warehouse()} set WAREHOUSE_SIZE = XSMALL").collect()
time.sleep(5)  # give Snowflake a few seconds to apply the new size

print(session.sql(show_wh_sql).collect())

In [None]:
# Score the held-out split (used for the comparison below) and the training split.
xgb_gs_predictions = xgb_gs_fitted.predict(test_data_all)
# NOTE(review): gb_gs_train is not used by the cells shown here (possibly meant xgb_gs_train).
gb_gs_train = xgb_gs_fitted.predict(train_data_all)

In [None]:
# Evaluate the retrained pipeline on its held-out split.
# NOTE: ACCURACY/AUC were measured on the NEW cohort only, while the retrained metrics come from
# a split of pooled old + new data, so the comparison is not strictly like-for-like.
ACCURACY_NEW = accuracy_score(df=xgb_gs_predictions, y_true_col_names=label_column_all, y_pred_col_names=output_column_all)
print(f'Accuracy (New Data): {ACCURACY}')
print(f'Accuracy (Retrained Model on Test Data): {ACCURACY_NEW}')


AUC_NEW = roc_auc_score(df=xgb_gs_predictions, y_true_col_names=label_column_all, y_score_col_names=output_column_all)
print(f'AUC (New Data): {AUC}')
print(f'AUC (Retrained Model on Test Data): {AUC_NEW}')

In [None]:
# FUNCTION used to iterate the model version so we can automatically
# create the next version number
import ast
import builtins  # explicit builtins: the notebook's snowpark.functions star import shadows max()
import re


def get_next_version(reg, model_name) -> str:
    """
    Returns the next version of a model based on the existing versions in the registry.

    Only versions whose name ends in ``_<number>`` (e.g. "V_3") are counted; others such as
    "V1" or "LATEST" are ignored. The model name is matched exactly first, then upper-cased,
    because unquoted Snowflake identifiers are stored in upper case.

    Args:
        reg: The registry object that provides access to the models (``show_models()`` must
            return a pandas DataFrame with 'name' and 'versions' columns).
        model_name: The name of the model.

    Returns:
        str: The next version of the model in the format "V_<n>"; "V_1" when the model is
        unknown or has no numbered versions.
    """
    models = reg.show_models()
    if models.empty:
        return "V_1"

    names = models["name"].to_list()
    if model_name in names:
        lookup_name = model_name
    elif str(model_name).upper() in names:
        lookup_name = str(model_name).upper()
    else:
        return "V_1"

    raw_versions = models.loc[models["name"] == lookup_name, "versions"].values[0]
    # SHOW MODELS reports versions as a list literal string, e.g. '["V_1","V_2"]'.
    if isinstance(raw_versions, str):
        versions = ast.literal_eval(raw_versions)
    elif isinstance(raw_versions, (list, tuple)):
        versions = raw_versions
    else:
        versions = []

    version_numbers = []
    for version_name in versions:
        match = re.fullmatch(r"(?:.*_)?([0-9]+)", str(version_name))
        if match:
            version_numbers.append(int(match.group(1)))

    if not version_numbers:
        return "V_1"
    return f"V_{builtins.max(version_numbers) + 1}"

## 6. Register the model retrained on new + old data if its accuracy improves

In [None]:
def train_model(session:Session, reg, new_model, model_name, acc_metric_old, acc_metric_new,
                min_improvement: float = 0.0) -> str:
    """Register ``new_model`` as the new default version of ``model_name`` if accuracy improved.

    Args:
        session: Not used; kept so existing call sites keep working.
        reg: Model registry to log the model into.
        new_model: Fitted model/pipeline to register.
        model_name: Registry model name.
        acc_metric_old: Accuracy of the currently deployed model.
        acc_metric_new: Accuracy of ``new_model``.
        min_improvement: Margin by which ``acc_metric_new`` must exceed ``acc_metric_old``
            (default 0.0, i.e. any strict improvement registers the model).

    Returns:
        str: Human-readable description of what happened.
    """
    if acc_metric_new > acc_metric_old + min_improvement:
        # The next version name is only needed when something is actually registered.
        model_version = get_next_version(reg, model_name)
        registered_model = reg.log_model(
            new_model,
            model_name=model_name,
            version_name=model_version,
            conda_dependencies=["snowflake-ml-python"],
            comment="Model trained using GridsearchCV in Snowpark to predict fraud claims",
            metrics={"Acc": acc_metric_new},
            options= {"relax_version": False}
        )
        # Set the new model as the default model
        reg.get_model(model_name).default = model_version
        return (f"Registered new model with version {registered_model.version_name} as the performance has improved"
                f"\nPrevious Accuracy Metric: {acc_metric_old}\nNew Accuracy Metric: {acc_metric_new}")
    return ("Model not updated as the model accuracy has not meaningfully improved."
            f"\nPrevious Accuracy Metric: {acc_metric_old}\nNew Accuracy Metric: {acc_metric_new}")

In [None]:
# Register the retrained pipeline only if it beats the deployed model's accuracy on the new data.
# Pass the active session instance, not the Session class.
# NOTE: ACCURACY and ACCURACY_NEW come from different evaluation sets (see the metrics cell).
train_model(session, reg, xgb_gs_fitted, 'XGB_GS_FRAUD_MODEL', ACCURACY, ACCURACY_NEW)