## pandas UDF examples in PySpark (based on https://qiita.com/taka_yayoi/items/b65197128ee698d87910)

### Series to Series pandas UDF

In [44]:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType
from pyspark.sql import SparkSession
# Create (or reuse, via getOrCreate) the SparkSession used by every cell below.
# The placeholder `.config("spark.some.config.option", "some-value")` copied from
# the Spark documentation was removed: it only set a meaningless option.
spark = (
    SparkSession.builder
    .appName("example")
    .getOrCreate()
)

### Define a plain pandas function

In [45]:
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    """Return the element-wise product of two pandas Series.

    Works as an ordinary pandas function and is later wrapped as a
    Series-to-Series pandas UDF.
    """
    product = a * b
    return product

### Run the function locally on a pandas Series (before wrapping it as a pandas_udf)

In [46]:
x = pd.Series([1,2,3])
display(multiply_func(x,x))

0    1
1    4
2    9
dtype: int64

### Create a Spark DataFrame

In [47]:
# Build a one-column Spark DataFrame (column "x") from the pandas Series `x`.
# `to_frame(name="x")` names the column explicitly. `pd.DataFrame(x, columns=["x"])`
# only relabels because `x` happens to be unnamed; for a named Series pandas
# matches `columns` by label instead, which would not give the intended data.
df = spark.createDataFrame(x.to_frame(name="x"))

In [48]:
df.show()

+---+
|  x|
+---+
|  1|
|  2|
|  3|
+---+



### Wrap the function as a Spark pandas UDF and execute it

In [49]:
# Wrap the plain pandas function as a Series-to-Series pandas UDF whose result type is LongType.
multiply = pandas_udf(multiply_func, LongType())

In [50]:
df.select(multiply(col("x"), col("x"))).show()

+-------------------+
|multiply_func(x, x)|
+-------------------+
|                  1|
|                  4|
|                  9|
+-------------------+



### Iterator of Series to Iterator of Series

In [51]:
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import col, pandas_udf, struct

In [52]:
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

In [53]:
@pandas_udf("long")
def plus_one(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Add 1 to every value, one batch at a time.

    Spark supplies an iterator of pandas Series (one per batch); this
    generator yields one transformed Series for each batch it receives.
    """
    yield from (batch + 1 for batch in batch_iter)
# Notes on Python generators (Japanese article): https://www.sejuku.net/blog/23716

In [54]:
df.select(plus_one(col("x"))).show()

+-----------+
|plus_one(x)|
+-----------+
|          2|
|          3|
|          4|
+-----------+



In [55]:
y_bc = spark.sparkContext.broadcast(1)

@pandas_udf("long")
def plus_y(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Add the broadcast value `y_bc` to every element of each incoming batch.

    The broadcast value is read once, when the generator starts, before the
    first batch is processed. The `finally` clause is the place to release
    any per-call resources; none are used here.
    """
    offset = y_bc.value  # one-time state setup shared by all batches
    try:
        yield from (batch + offset for batch in batch_iter)
    finally:
        pass  # release resources here, if any

In [56]:
df.select(plus_y(col("x"))).show()

+---------+
|plus_y(x)|
+---------+
|        2|
|        3|
|        4|
+---------+



### Iterator of multiple Series to Iterator of Series

In [57]:
from typing import Iterator, Tuple
import pandas as pd

from pyspark.sql.functions import col, pandas_udf, struct

In [58]:
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

In [59]:
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Multiply two columns element-wise, batch by batch.

    Each item of `iterator` is a tuple holding one batch of each input
    column; one product Series is yielded per tuple.
    """
    yield from (left * right for left, right in iterator)

In [60]:
df.select(multiply_two_cols("x", "x")).show()

+-----------------------+
|multiply_two_cols(x, x)|
+-----------------------+
|                      1|
|                      4|
|                      9|
+-----------------------+



### Series to scalar pandas UDF

In [61]:
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window

In [62]:
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

In [63]:
# Declare the function and create the UDF
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    """Series-to-scalar pandas UDF returning the mean of `v` as a double.

    Can be used in `select` (whole column), in `groupby().agg`, and over a
    window with `.over(...)`.
    """
    average = v.mean()
    return average

In [64]:
df.select(mean_udf(df['v'])).show()

+-----------+
|mean_udf(v)|
+-----------+
|        4.2|
+-----------+



In [65]:
df.groupby("id").agg(mean_udf(df['v'])).show()

+---+-----------+
| id|mean_udf(v)|
+---+-----------+
|  1|        1.5|
|  2|        6.0|
+---+-----------+



In [66]:
# Frame spanning the whole `id` partition, so every row sees its full group's mean.
w = Window.partitionBy('id').rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing
)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()

+---+----+------+
| id|   v|mean_v|
+---+----+------+
|  1| 1.0|   1.5|
|  1| 2.0|   1.5|
|  2| 3.0|   6.0|
|  2| 5.0|   6.0|
|  2|10.0|   6.0|
+---+----+------+

