In [None]:
from snowflake.snowpark.context import get_active_session
# Session handle for this notebook; the cells below use it to read the source
# table and to open the model registry.
session = get_active_session()

In [None]:
# Reference the CUSTOMER_DATA_1000 table through the active session.
df = session.table("CUSTOMER_DATA_1000")
# Bare last expression so the notebook renders the DataFrame as the cell output.
df

In [None]:
# Split the source rows into two parts using weights 0.8 / 0.2; the seed is
# fixed at 42 so the same seed is used on every run.
train_df, test_df = df.random_split(weights=[0.8, 0.2], seed=42)

# Print the row count of each part on one line: train first, then test.
print(*(part.count() for part in (train_df, test_df)))

In [None]:
from snowflake.ml.modeling.impute import SimpleImputer
from snowflake.ml.modeling.preprocessing import StandardScaler, OrdinalEncoder
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.modeling.pipeline import Pipeline

In [None]:
# Preprocessing + model pipeline for the CHURN classifier.
#
# Every step names its input and output columns explicitly. Column names are
# spelled in upper case everywhere so that the name a step writes is exactly
# the name the following steps (and the later result projection) read, without
# relying on identifier case folding.
# `passthrough_cols` lists the columns each step is told to leave untouched.

# Imputation (numeric): fill missing AGE / ANNUAL_INCOME with the median.
# Output columns reuse the input names.
numeric_imputer = SimpleImputer(
    strategy="median",
    input_cols=["AGE", "ANNUAL_INCOME"],
    output_cols=["AGE", "ANNUAL_INCOME"],
    passthrough_cols=["GENDER", "CHURN"],
    drop_input_cols=True,
)

# Imputation (string): fill missing GENDER with the most frequent value.
categorical_imputer = SimpleImputer(
    strategy="most_frequent",
    input_cols=["GENDER"],
    output_cols=["GENDER"],
    passthrough_cols=["AGE", "ANNUAL_INCOME", "CHURN"],
    drop_input_cols=True,
)

# Encoding: turn GENDER into the ordinal column GENDER_CODE.
# The output name was previously spelled "GENDER_code" while every consumer
# asks for "GENDER_CODE"; it is now spelled identically everywhere.
encoder = OrdinalEncoder(
    input_cols=["GENDER"],
    output_cols=["GENDER_CODE"],
    passthrough_cols=["AGE", "ANNUAL_INCOME", "CHURN"],
    drop_input_cols=True,
)

# Scaling: standardize AGE and ANNUAL_INCOME; output columns reuse the input names.
scaler = StandardScaler(
    input_cols=["AGE", "ANNUAL_INCOME"],
    output_cols=["AGE", "ANNUAL_INCOME"],
    passthrough_cols=["GENDER_CODE", "CHURN"],
    drop_input_cols=True,
)

# Model: XGBoost classifier trained on the three prepared features, with CHURN
# as the label; predictions are written to PREDICTED_CHURN.
xgb = XGBClassifier(
    input_cols=["AGE", "ANNUAL_INCOME", "GENDER_CODE"],
    label_cols=["CHURN"],
    output_cols=["PREDICTED_CHURN"],
)

# Step order matters: imputers first, then the encoder (which creates
# GENDER_CODE), then the scaler, and finally the classifier.
pipeline = Pipeline(
    steps=[
        ("num_imputer", numeric_imputer),
        ("cat_imputer", categorical_imputer),
        ("encoder", encoder),
        ("scaler", scaler),
        ("classifier", xgb),
    ]
)

# Fit all steps on the training split; `model` is what later cells use to
# predict and what gets logged to the registry.
model = pipeline.fit(train_df)

In [None]:
# Score the held-out split with the fitted pipeline.
pred = model.predict(test_df)

# Cell output: project the prediction result onto the columns of interest.
pred.select(
    [
        "ID",
        "AGE",
        "GENDER_CODE",
        "ANNUAL_INCOME",
        "CHURN",
        "PREDICTED_CHURN",
    ]
)

In [None]:
# Print a preview of the prediction rows (all columns) to the cell output.
pred.show()

In [None]:
from sklearn.metrics import confusion_matrix

# Materialize the predictions as a pandas DataFrame for scikit-learn.
pred_df = pred.to_pandas()

# Confusion matrix of the actual labels (CHURN) against the predicted labels
# (PREDICTED_CHURN); the true labels are passed first, the predictions second.
cm = confusion_matrix(
    y_true=pred_df["CHURN"],
    y_pred=pred_df["PREDICTED_CHURN"],
)
print(cm)

In [None]:
-- Create the database and schema used to store registered models.
-- NOTE(review): CREATE OR REPLACE recreates an existing object of the same
-- name, so re-running this cell discards whatever MY_ML_REGISTRY already
-- contained. Confirm this is intended outside a throwaway demo.
CREATE OR REPLACE DATABASE MY_ML_REGISTRY;
-- Make MY_ML_REGISTRY the current database so the schema below is created in it.
USE MY_ML_REGISTRY;
CREATE OR REPLACE SCHEMA MODELS;

In [None]:
-- Make MODELS the current schema of the session.
USE SCHEMA MODELS;

In [None]:
from snowflake.ml.registry import Registry

# Open the model registry located in MY_ML_REGISTRY.MODELS.
# The database and schema are passed explicitly so the registry location does
# not depend on which `USE ...` statements the session ran beforehand (for
# example, when cells are executed out of order or the SQL cells are skipped).
registry = Registry(
    session=session,
    database_name="MY_ML_REGISTRY",
    schema_name="MODELS",
)

# Register the fitted pipeline as version "v1" of customer_churn_model.
# The call is left as the last expression so its return value is shown as the
# cell output.
# NOTE(review): the comment text is a placeholder; replace it with a real
# description of the model before sharing the registry.
registry.log_model(
    model=model,
    model_name="customer_churn_model",
    version_name="v1",
    comment="hogehoge",
)