<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
import argparse
import os
from pyspark.sql import SparkSession
import mlflow
from mlflow.tracking import MlflowClient
from pyspark.sql.functions import col

In [None]:
def process(spark, data_path, result_path, model_uri="models:/e-lavrushkin/v3"):
    """
    Run the main job: load the dataset, apply the MLflow model, save the result.

    :param spark: SparkSession
    :param data_path: path to the input dataset (read as Parquet)
    :param result_path: path where the result is written (Parquet, overwrite mode)
    :param model_uri: MLflow model URI handed to ``mlflow.pyfunc.spark_udf``;
        defaults to the value that used to be hard-coded in this function
    """
    data = spark.read.parquet(data_path)

    # NOTE(review): the default URI ends in "v3", which is not a plain numeric
    # registry version -- confirm that the target MLflow registry resolves it.
    predict = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)

    # Every column of the input frame is passed to the model, in frame order.
    feature_columns = [col(name) for name in data.columns]
    predictions = data.withColumn("prediction", predict(*feature_columns))

    # Existing output at result_path is replaced.
    predictions.write.mode("overwrite").parquet(result_path)

In [None]:
def main(data, result):
    """
    Create a SparkSession and run the prediction job.

    :param data: path to the input dataset
    :param result: path where the result is saved
    """
    process(_spark_session(), data, result)

In [None]:
def _spark_session():
    """
    Create a SparkSession (or return the already existing one).

    :return: SparkSession
    """
    builder = SparkSession.builder
    named_builder = builder.appName('PySparkPredict')
    return named_builder.getOrCreate()

In [None]:
if __name__ == "__main__":
    # Command-line entry point: read the input/output paths and run the job.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data',
        type=str,
        default='data.parquet',
        help='Please set datasets path.',
    )
    parser.add_argument(
        '--result',
        type=str,
        default='result',
        help='Please set result path.',
    )
    args = parser.parse_args()
    data, result = args.data, args.result
    main(data, result)