In [2]:
from pyspark.sql import SparkSession

spark = (
    SparkSession
        .builder
        .master("local[*]")
        .appName("UDF")
        .getOrCreate()
)

In [3]:
train = spark.read.csv("../data/train.csv", inferSchema=True, header=True).cache()

                                                                                

In [4]:
from pyspark.sql import functions as f
from pyspark.sql.functions import col
from pyspark.sql.types import LongType

def to_months(ms):
    return ms // 31536000000 // 12 #1 year = 31536000000 ms

to_months_udf = f.udf(to_months, LongType())

In [5]:
%%time
(
    train
        .select("content_id", to_months_udf("timestamp").alias("month"))
        .groupBy("content_id")
        .mean("month")
        .show()
)



+----------+----------+
|content_id|avg(month)|
+----------+----------+
|      4519|       0.0|
|      4818|       0.0|
|      5518|       0.0|
|     13285|       0.0|
|     12027|       0.0|
|       833|       0.0|
|      9427|       0.0|
|       496|       0.0|
|      5156|       0.0|
|      6336|       0.0|
|      2866|       0.0|
|      1959|       0.0|
|      7982|       0.0|
|     23336|       0.0|
|       148|       0.0|
|      1342|       0.0|
|      1088|       0.0|
|      1580|       0.0|
|       471|       0.0|
|      2122|       0.0|
+----------+----------+
only showing top 20 rows

CPU times: user 65.7 ms, sys: 92.7 ms, total: 158 ms
Wall time: 4min 7s


                                                                                

In [6]:
%%time
(
 train
    .select("content_id", (col("timestamp") / 31536000000 / 12).alias("month"))
    .groupBy("content_id")
    .mean("month")
    .show()
)



+----------+--------------------+
|content_id|          avg(month)|
+----------+--------------------+
|      4519|0.015429677931821352|
|      4818|0.016230649272587192|
|      5518| 0.02103952664360606|
|     13285| 0.04141825984987046|
|     12027| 0.04150963056316861|
|       833|0.022584471034752644|
|      9427| 0.02158891106034129|
|       496|0.020155351740181845|
|      5156|0.018064186032779997|
|      6336|0.019195051162991618|
|      2866|0.026921431007821156|
|      1959| 0.02670915106444127|
|      7982| 0.02122819619071833|
|     23336|0.019260306662998706|
|       148|0.017910791660268703|
|      1342| 0.02195320830526077|
|      1088| 0.02461949633696784|
|      1580|0.022461073541947943|
|       471| 0.02611517946736722|
|      2122|0.026057020501248997|
+----------+--------------------+
only showing top 20 rows

CPU times: user 12.9 ms, sys: 7.53 ms, total: 20.4 ms
Wall time: 4.74 s


                                                                                

Создадим Pandas UDFs (a.k.a. Vectorized UDFs)

In [7]:
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf('double')
def pandas_to_months(series: pd.Series) -> pd.Series:
    return series / 31536000000 / 12

In [8]:
%%time
(
    train
        .select("content_id", pandas_to_months("timestamp").alias("month"))
        .groupBy("content_id")
        .mean("month")
        .show()
)



+----------+--------------------+
|content_id|          avg(month)|
+----------+--------------------+
|      4519|0.015429677931821352|
|      4818|0.016230649272587192|
|      5518| 0.02103952664360606|
|     13285| 0.04141825984987046|
|     12027| 0.04150963056316861|
|       833|0.022584471034752644|
|      9427| 0.02158891106034129|
|       496|0.020155351740181845|
|      5156|0.018064186032779997|
|      6336|0.019195051162991618|
|      2866|0.026921431007821156|
|      1959| 0.02670915106444127|
|      7982| 0.02122819619071833|
|     23336|0.019260306662998706|
|       148|0.017910791660268703|
|      1342| 0.02195320830526077|
|      1088| 0.02461949633696784|
|      1580|0.022461073541947943|
|       471| 0.02611517946736722|
|      2122|0.026057020501248997|
+----------+--------------------+
only showing top 20 rows

CPU times: user 32.5 ms, sys: 40.9 ms, total: 73.4 ms
Wall time: 24.4 s


                                                                                