In [1]:
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
# Hand the same Netty flag to the driver JVM and to every executor JVM.
# NOTE(review): presumably required by the Arrow-backed pandas UDFs below when
# running on a newer JDK - confirm against the Spark version in use.
conf = (
    SparkConf()
    .set("spark.driver.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
    .set("spark.executor.extraJavaOptions", "-Dio.netty.tryReflectionSetAccessible=true")
)
# Reuse an already running session if there is one, otherwise start a new one.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

In [6]:
from pyspark.sql.types import *
from pyspark.sql.functions import col, count, rand, collect_list, explode, struct, count, lit
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.functions import udf, log

In [3]:
# Benchmark data: 10 million rows.
#   id - row number divided by 10000 and cast to integer, i.e. 1000 groups (0..999)
#   v  - a random value drawn by rand()
df = (
    spark.range(0, 10_000_000)
    .withColumn('id', (col('id') / 10000).cast('integer'))
    .withColumn('v', rand())
)

In [4]:

# Mark df as cached and immediately run an action so the cache is populated;
# the count is the cell's displayed result.
df.cache().count()

10000000

In [7]:
@udf('double')
def plus_one(v):
    """Row-at-a-time UDF: return the given column value increased by one.

    Called once per row with a single value of the column.
    """
    incremented = v + 1
    return incremented

In [9]:
%time
df.withColumn('v', plus_one(df.v)).agg(count(col('v'))).show()

Wall time: 0 ns
+--------+
|count(v)|
+--------+
|10000000|
+--------+



In [11]:
# Eyeball the result of the row-at-a-time UDF: v is replaced by v + 1.
df.withColumn('v', plus_one(df['v'])).show()

+---+------------------+
| id|                 v|
+---+------------------+
|  0|1.0922988208363424|
|  0|1.6769451301868872|
|  0|1.3104433460112879|
|  0|1.2508636169081047|
|  0|1.9313903958594187|
|  0| 1.183732121648227|
|  0| 1.495188968500107|
|  0|1.5238617573339335|
|  0| 1.712705269168485|
|  0|1.5013970660553224|
|  0| 1.781912740230034|
|  0|1.2468175949410767|
|  0|1.7299890196075802|
|  0|1.5028392001569828|
|  0|1.5388136073133867|
|  0|1.2700669038386945|
|  0|1.8643552209460785|
|  0| 1.133052061707056|
|  0|1.5154161158068395|
|  0|1.7423036614842236|
+---+------------------+
only showing top 20 rows



In [12]:
import pandas as pd

@pandas_udf("double")
def pandas_plus_one(v: pd.Series) -> pd.Series:
    """Vectorized counterpart of ``plus_one``.

    Receives a pandas Series of column values and returns a Series of the
    same length with one added to every element.
    """
    return v.add(1)

In [13]:
%timeit df.withColumn('v', pandas_plus_one(df.v)).agg(count(col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

+--------+
|count(v)|
+--------+
|10000000|
+--------+

3.31 s ± 69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
from scipy import stats

@udf('double')
def cdf(v):
    """Row-at-a-time UDF: standard normal cumulative probability of ``v``.

    Evaluates ``scipy.stats.norm.cdf`` for a single value.
    """
    probability = stats.norm.cdf(v)
    # Convert scipy's result to a built-in float to match the declared
    # 'double' return type.
    return float(probability)

%timeit df.withColumn('cumulative_probability', cdf(df.v)).show()

+---+-------------------+----------------------+
| id|                  v|cumulative_probability|
+---+-------------------+----------------------+
|  0|0.09229882083634233|    0.5367696873956587|
|  0| 0.6769451301868872|    0.7507796157042829|
|  0|0.31044334601128776|    0.6218880820707676|
|  0|0.25086361690810466|    0.5990402227187244|
|  0| 0.9313903958594187|    0.8241741714617743|
|  0|0.18373212164822705|    0.5728881957857559|
|  0|0.49518896850010696|    0.6897666315801746|
|  0| 0.5238617573339335|    0.6998126522807955|
|  0|  0.712705269168485|    0.7619859223773182|
|  0| 0.5013970660553224|     0.691954147881658|
|  0| 0.7819127402300341|    0.7828670714230717|
|  0| 0.2468175949410767|      0.59747530355293|
|  0| 0.7299890196075801|    0.7673015517767563|
|  0| 0.5028392001569828|    0.6924613346944156|
|  0| 0.5388136073133867|    0.7049922632551755|
|  0|0.27006690383869447|    0.6064456083742223|
|  0| 0.8643552209460785|    0.8063036081285964|
|  0|0.1330520617070

+---+-------------------+----------------------+
| id|                  v|cumulative_probability|
+---+-------------------+----------------------+
|  0|0.09229882083634233|    0.5367696873956587|
|  0| 0.6769451301868872|    0.7507796157042829|
|  0|0.31044334601128776|    0.6218880820707676|
|  0|0.25086361690810466|    0.5990402227187244|
|  0| 0.9313903958594187|    0.8241741714617743|
|  0|0.18373212164822705|    0.5728881957857559|
|  0|0.49518896850010696|    0.6897666315801746|
|  0| 0.5238617573339335|    0.6998126522807955|
|  0|  0.712705269168485|    0.7619859223773182|
|  0| 0.5013970660553224|     0.691954147881658|
|  0| 0.7819127402300341|    0.7828670714230717|
|  0| 0.2468175949410767|      0.59747530355293|
|  0| 0.7299890196075801|    0.7673015517767563|
|  0| 0.5028392001569828|    0.6924613346944156|
|  0| 0.5388136073133867|    0.7049922632551755|
|  0|0.27006690383869447|    0.6064456083742223|
|  0| 0.8643552209460785|    0.8063036081285964|
|  0|0.1330520617070

In [16]:
from pyspark.sql import Row
@udf(ArrayType(df.schema))
def substract_mean(rows):
    """Row-at-a-time baseline for the grouped "subtract the mean" transform.

    ``rows`` is the list of (id, v) structs collected for one group. Returns a
    list of the same length, in the same order, in which every ``v`` has the
    group's mean of ``v`` subtracted and ``id`` is carried over unchanged.
    """
    values = pd.Series([row.v for row in rows])
    centered = values - values.mean()
    # Pair every input row with its centered value; float() turns the NumPy
    # scalar into a built-in float before it goes into the Row.
    return [Row(id=row['id'], v=float(delta)) for row, delta in zip(rows, centered)]
  
# Grouped "subtract the mean" built from a plain (row-at-a-time) UDF:
#   1. group by id and collect each group's (id, v) structs into one list column `rows`
#   2. run substract_mean on that list to get `new_rows`
#   3. explode `new_rows` back into one row per element (`new_row`)
#   4. pull id and v out of the struct, then count v to force full evaluation
# The recorded run of this cell ended in a KeyboardInterrupt, i.e. it was
# stopped by hand before %timeit finished.
%timeit df.groupby('id').agg(collect_list(struct(df['id'], df['v'])).alias('rows')).withColumn('new_rows', substract_mean(col('rows'))).withColumn('new_row', explode(col('new_rows'))).withColumn('id', col('new_row.id')).withColumn('v', col('new_row.v')).agg(count(col('v'))).show()

KeyboardInterrupt: 

In [None]:
# Takes one pandas.DataFrame and hands back a pandas.DataFrame
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    """Center column ``v`` of ``pdf`` around zero.

    Returns a copy of ``pdf`` in which ``v`` has its own mean subtracted;
    every other column is passed through and ``pdf`` itself is not modified.
    """
    centered = pdf.v - pdf.v.mean()
    return pdf.assign(v=centered)
%time
df.groupby('id').applyInPandas(subtract_mean, schema=df.schema).agg(count(col('v'))).show()

Wall time: 0 ns


In [None]:
# Regression demo data: keep id from df and add three freshly drawn random
# columns - the response y and the regressors x1 and x2.
df2 = (
    df.withColumn('y', rand())
    .withColumn('x1', rand())
    .withColumn('x2', rand())
    .select('id', 'y', 'x1', 'x2')
)
df2.show()

In [None]:
import statsmodels.api as sm
# df2 has four columns: id, y, x1, x2
group_column = 'id'        # one regression is fitted per value of this column
y_column = 'y'             # response variable
x_columns = ['x1', 'x2']   # regressors
# Schema of the per-group result: the group key followed by one column per regressor.
schema = df2.select([group_column, *x_columns]).schema
# Takes one pandas.DataFrame (a single group) and hands back a pandas.DataFrame
def ols(pdf: pd.DataFrame) -> pd.DataFrame:
    """Fit an ordinary least squares regression on one group.

    Regresses ``y_column`` on ``x_columns`` plus an added constant term and
    returns a one-row DataFrame holding the group key followed by the fitted
    coefficient of each regressor. The constant's coefficient is not returned.
    """
    # The caller groups by group_column, so the key is read from the first row.
    key = pdf[group_column].iloc[0]
    design = sm.add_constant(pdf[x_columns])
    fitted = sm.OLS(pdf[y_column], design).fit()
    coefficients = [fitted.params[name] for name in x_columns]
    return pd.DataFrame([[key] + coefficients], columns=[group_column] + x_columns)
# Run ols once per group; each group contributes a single row of coefficients.
beta = (
    df2.groupby(group_column)
    .applyInPandas(ols, schema=schema)
)
beta.show()