In [1]:
import io
import sys

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession

import numpy as np

In [2]:
# Paths used by this notebook:
#   MODEL_PATH  - where the saved PipelineModel is loaded from
#   input_file  - parquet data for which ctr must be predicted
#   output_file - where the results (columns [ad_id, prediction]) are written
MODEL_PATH, input_file, output_file = (
    'spark_ml_model',
    'test.parquet',
    'result',
)

In [3]:
spark = SparkSession.builder.appName('PySparkMLPredict').getOrCreate()

In [4]:
spark

### Inspect the test dataset

In [5]:
test = spark.read.parquet(input_file)

In [6]:
test.show(5)

+-----+---------------------+---------+------+------+----------------+---------+-----------------+
|ad_id|target_audience_count|has_video|is_cpm|is_cpc|         ad_cost|day_count|              ctr|
+-----+---------------------+---------+------+------+----------------+---------+-----------------+
|    2|     11012.2068140534|        1|     1|     0|196.691891825393|       17| 0.50005065193925|
|    3|     9923.69112524699|        1|     1|     0|202.617038691842|       15|0.637132195277704|
|    4|     10202.3140990505|        1|     1|     0|203.496891469936|       15|0.783706394973096|
|   10|     10239.9431887051|        1|     1|     0|195.804239443196|       15| 1.01044552869544|
|   13|     8373.52511906263|        1|     1|     0|202.221614839989|       13| 1.05570252090352|
+-----+---------------------+---------+------+------+----------------+---------+-----------------+
only showing top 5 rows



### Load the trained model

In [7]:
model = PipelineModel.load(MODEL_PATH)

In [8]:
prediction = model.transform(test)

In [9]:
prediction.show(5)

+-----+---------------------+---------+------+------+----------------+---------+-----------------+--------------------+------------------+
|ad_id|target_audience_count|has_video|is_cpm|is_cpc|         ad_cost|day_count|              ctr|            features|        prediction|
+-----+---------------------+---------+------+------+----------------+---------+-----------------+--------------------+------------------+
|    2|     11012.2068140534|        1|     1|     0|196.691891825393|       17| 0.50005065193925|[1.0,1.0,0.0,196....|1.3012103231330485|
|    3|     9923.69112524699|        1|     1|     0|202.617038691842|       15|0.637132195277704|[1.0,1.0,0.0,202....| 2.102279397581597|
|    4|     10202.3140990505|        1|     1|     0|203.496891469936|       15|0.783706394973096|[1.0,1.0,0.0,203....|1.8965282187698946|
|   10|     10239.9431887051|        1|     1|     0|195.804239443196|       15| 1.01044552869544|[1.0,1.0,0.0,195....|1.8716882642094221|
|   13|     8373.5251190626

In [10]:
df_to_save = prediction[['ad_id', 'prediction']]

In [11]:
df_to_save.coalesce(1).write.parquet(output_file)