In [0]:
import mlflow
import mlflow.spark
from mlflow.exceptions import RestException
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from mlflow.models import infer_signature

In [0]:
# Train an implicit-feedback ALS recommender on the 30-day interaction table
# and log the fitted model, its signature, and its hyperparameters to MLflow.
# NOTE(review): `spark` is expected to be the notebook-provided SparkSession;
# it is not defined in this cell.

interactions = spark.table("MLOps.data.als_interactions_30d")

# ALS's itemCol must be numeric, so map product_id onto a numeric index.
# handleInvalid="skip" drops rows whose product_id cannot be indexed.
product_indexer = StringIndexer(
    inputCol="product_id",
    outputCol="product_id_idx",
    handleInvalid="skip"
)
interactions_indexed = product_indexer.fit(interactions).transform(interactions)

# Single source of truth for the tuned hyperparameters: the same dict
# configures ALS and is logged to MLflow, so the two cannot drift apart.
als_params = {
    "rank": 20,
    "maxIter": 10,
    "regParam": 0.1,
}

als = ALS(
    # NOTE(review): ALS expects integer-typed user ids -- confirm the schema
    # of customer_id in MLOps.data.als_interactions_30d.
    userCol="customer_id",
    itemCol="product_id_idx",
    ratingCol="interaction_weight",
    implicitPrefs=True,
    # Drop NaN predictions for users/items that were not seen during fit.
    coldStartStrategy="drop",
    **als_params,
)

experiment_path = "/Workspace/Users/jung@ap-com.co.jp/mlops_demo_model/als_recommendation"

try:
    mlflow.set_experiment(experiment_path)
    print(f"Experiment found or created at: {experiment_path}")
except RestException as e:
    # Fallback: if the server reports the experiment as missing, create it
    # explicitly and activate it. Any other REST error is re-raised unchanged.
    if "RESOURCE_DOES_NOT_EXIST" not in str(e):
        raise
    experiment_id = mlflow.create_experiment(experiment_path)
    mlflow.set_experiment(experiment_path)
    print(f"Experiment created at: {experiment_path}, ID: {experiment_id}")

# Hold out 20% of interactions; the fixed seed keeps the split reproducible.
train, test = interactions_indexed.randomSplit([0.8, 0.2], seed=42)

with mlflow.start_run(run_name="training") as run:
    # Log hyperparameters first so they are recorded even if fitting fails.
    mlflow.log_params(als_params)

    # Fit on the training split only. Fitting on the full dataset would leak
    # the held-out test rows into the model.
    model = als.fit(train)

    # The model signature should describe only what ALS actually consumes
    # (user and item ids), not every column of the source table.
    model_input = test.select("customer_id", "product_id_idx")
    predictions = model.transform(model_input)

    sample_input = model_input.limit(5).toPandas()
    sample_output = predictions.select("prediction").limit(5).toPandas()
    signature = infer_signature(sample_input, sample_output)

    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="als_model",
        signature=signature
    )

    run_id = run.info.run_id

print(f"Training finished. Run ID: {run_id}")