In [1]:
import pyspark
from pyspark.sql import SparkSession
import warnings
warnings.filterwarnings("ignore")

spark = SparkSession.builder.config("spark.driver.memory", "4g").appName('MCS').getOrCreate()



In [2]:
stocks = spark.read.csv(["data/stocks/ABAX.csv","data/stocks/AAME.csv","data/stocks/AEPI.csv"], header='true', inferSchema='true')
stocks.show(2)

+---------+-----+-----+-----+-----+------+
|     Date| Open| High|  Low|Close|Volume|
+---------+-----+-----+-----+-----+------+
|31-Dec-13|52.94|54.37|52.25|52.83| 79429|
|30-Dec-13|50.36|54.10|50.36|52.95|131095|
+---------+-----+-----+-----+-----+------+
only showing top 2 rows



In [3]:
print((stocks.count(), len(stocks.columns)))

from pyspark.sql import functions as fun
stocks = stocks.withColumn("Symbol", fun.input_file_name()).withColumn("Symbol",fun.element_at(fun.split("Symbol", "/"), -1)).withColumn("Symbol",fun.element_at(fun.split("Symbol", "\."), 1))
stocks.show(2)

(10546, 6)
+---------+-----+-----+-----+-----+------+------+
|     Date| Open| High|  Low|Close|Volume|Symbol|
+---------+-----+-----+-----+-----+------+------+
|31-Dec-13|52.94|54.37|52.25|52.83| 79429|  AEPI|
|30-Dec-13|50.36|54.10|50.36|52.95|131095|  AEPI|
+---------+-----+-----+-----+-----+------+------+
only showing top 2 rows



In [4]:
factors = spark.read.csv(["data/stocks/ABAX.csv","data/stocks/AAME.csv","data/stocks/AEPI.csv"], header='true', inferSchema='true')
factors = factors.withColumn("Symbol", fun.input_file_name()).withColumn("Symbol",fun.element_at(fun.split("Symbol", "/"), -1)).\
withColumn("Symbol",fun.element_at(fun.split("Symbol", "\."), 1))

In [5]:
from pyspark.sql import Window
stocks = stocks.withColumn('count', fun.count('Symbol').over(Window.partitionBy('Symbol'))).filter(fun.col('count') > 260*5 + 10)
spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")
stocks = stocks.withColumn('Date',fun.to_date(fun.to_timestamp(fun.col('Date'),'dd-MMM-yy')))
stocks.printSchema()

root
 |-- Date: date (nullable = true)
 |-- Open: string (nullable = true)
 |-- High: string (nullable = true)
 |-- Low: string (nullable = true)
 |-- Close: double (nullable = true)
 |-- Volume: integer (nullable = true)
 |-- Symbol: string (nullable = true)
 |-- count: long (nullable = false)



In [6]:
from datetime import datetime
stocks = stocks.filter(fun.col('Date') >= datetime(2009, 10, 23)).filter(fun.col('Date') <= datetime(2014, 10, 23))

print((stocks.count(), len(stocks.columns)))

(3145, 8)


In [7]:
factors = factors.withColumn('Date',
fun.to_date(fun.to_timestamp(fun.col('Date'),'dd-MMM-yy')))
factors = factors.filter(fun.col('Date') >= datetime(2009, 10, 23)).filter(fun.col('Date') <= datetime(2014, 10, 23))
factors.show(2)

+----------+-----+-----+-----+-----+------+------+
|      Date| Open| High|  Low|Close|Volume|Symbol|
+----------+-----+-----+-----+-----+------+------+
|2013-12-31|52.94|54.37|52.25|52.83| 79429|  AEPI|
|2013-12-30|50.36|54.10|50.36|52.95|131095|  AEPI|
+----------+-----+-----+-----+-----+------+------+
only showing top 2 rows



In [8]:
stocks_pd_df = stocks.toPandas()
factors_pd_df = factors.toPandas()
factors_pd_df.head(5)

Unnamed: 0,Date,Open,High,Low,Close,Volume,Symbol
0,2013-12-31,52.94,54.37,52.25,52.83,79429,AEPI
1,2013-12-30,50.36,54.1,50.36,52.95,131095,AEPI
2,2013-12-27,50.38,50.8,49.67,50.52,54354,AEPI
3,2013-12-26,50.5,51.19,49.67,50.0,74414,AEPI
4,2013-12-24,49.85,50.6,49.66,49.99,36872,AEPI


In [9]:
n_steps = 10
def my_fun(x):
    return ((x.iloc[-1] - x.iloc[0]) / x.iloc[0])
stock_returns = stocks_pd_df.groupby('Symbol').Close.rolling(window=n_steps).apply(my_fun)
factors_returns = factors_pd_df.groupby('Symbol').Close.rolling(window=n_steps).apply(my_fun)
stock_returns = stock_returns.reset_index().sort_values('level_1').reset_index()
factors_returns = factors_returns.reset_index().sort_values('level_1').reset_index()

In [10]:
stocks_pd_df_with_returns = stocks_pd_df.assign(stock_returns = stock_returns['Close'])
factors_pd_df_with_returns = factors_pd_df.assign(factors_returns = \
factors_returns['Close'],factors_returns_squared = factors_returns['Close']**2)
factors_pd_df_with_returns = factors_pd_df_with_returns.pivot(index='Date',columns='Symbol',values=['factors_returns', 'factors_returns_squared'])
factors_pd_df_with_returns.columns = factors_pd_df_with_returns.columns.to_series().str.join('_').reset_index()[0]
factors_pd_df_with_returns = factors_pd_df_with_returns.reset_index()
print(factors_pd_df_with_returns.head(1))

0        Date  factors_returns_AAME  factors_returns_ABAX  \
0  2009-10-23              0.111111               0.10064   

0  factors_returns_AEPI  factors_returns_squared_AAME  \
0              0.097994                      0.012346   

0  factors_returns_squared_ABAX  factors_returns_squared_AEPI  
0                      0.010128                      0.009603  


In [11]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
from sklearn.linear_model import LinearRegression
# For each stock, create input DF for linear regression training
stocks_factors_combined_df = pd.merge(stocks_pd_df_with_returns,factors_pd_df_with_returns,how="left", on="Date")
feature_columns = list(stocks_factors_combined_df.columns[-6:])

with pd.option_context('mode.use_inf_as_na', True):
    stocks_factors_combined_df = stocks_factors_combined_df.dropna(subset=feature_columns + ['stock_returns'])

def find_ols_coef(df):
    y = df[['stock_returns']].values
    X = df[feature_columns]
    regr = LinearRegression()
    regr_output = regr.fit(X, y)
    return list(df[['Symbol']].values[0]) + list(regr_output.coef_[0])

coefs_per_stock = stocks_factors_combined_df.groupby('Symbol').apply(find_ols_coef)
coefs_per_stock = pd.DataFrame(coefs_per_stock).reset_index()
coefs_per_stock.columns = ['symbol', 'factor_coef_list']
coefs_per_stock = pd.DataFrame(coefs_per_stock.factor_coef_list.tolist(),index=coefs_per_stock.index,columns = ['Symbol'] + feature_columns)
coefs_per_stock

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, positive=False):
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  max_n_alphas=1000, n_jobs=1, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  max_n_alphas=1

Unnamed: 0,Symbol,factors_returns_AAME,factors_returns_ABAX,factors_returns_AEPI,factors_returns_squared_AAME,factors_returns_squared_ABAX,factors_returns_squared_AEPI
0,AAME,1.0,-1.240376e-15,9.220346e-16,1.607635e-16,-9.631756000000001e-17,8.117732000000001e-17
1,ABAX,5.729698e-16,1.0,-4.242853e-16,5.503443e-19,-3.6079320000000004e-17,2.408654e-17
2,AEPI,-3.140315e-16,9.080787e-17,1.0,-9.744513000000001e-17,7.967819000000001e-17,-7.846324e-17


In [12]:
samples = factors_returns.loc[factors_returns.Symbol == \
factors_returns.Symbol.unique()[0]]['Close']
samples.plot.kde()

<matplotlib.axes._subplots.AxesSubplot at 0x7f8356f03c18>

In [13]:
f_1 = factors_returns.loc[factors_returns.Symbol == factors_returns.Symbol.unique()[0]]['Close']
f_2 = factors_returns.loc[factors_returns.Symbol == factors_returns.Symbol.unique()[1]]['Close']
f_3 = factors_returns.loc[factors_returns.Symbol == factors_returns.Symbol.unique()[2]]['Close']
print(f_1.size,len(f_2),f_3.size)
pd.DataFrame({'f1': list(f_1)[1:1040], 'f2': list(f_2)[1:1040], 'f3':list(f_3)}).corr()

1053 1053 1039


Unnamed: 0,f1,f2,f3
f1,1.0,0.275057,-0.015415
f2,0.275057,1.0,0.03137
f3,-0.015415,0.03137,1.0


In [14]:
factors_returns_cov = pd.DataFrame({'f1': list(f_1)[1:1040],'f2': list(f_2)[1:1040],'f3': list(f_3)}).cov().to_numpy()
factors_returns_mean = pd.DataFrame({'f1': list(f_1)[1:1040],'f2': list(f_2)[1:1040],'f3': list(f_3)}).mean()

from numpy.random import multivariate_normal
multivariate_normal(factors_returns_mean, factors_returns_cov)

array([0.03773615, 0.00190841, 0.05145688])

In [15]:
b_coefs_per_stock = spark.sparkContext.broadcast(coefs_per_stock)
b_feature_columns = spark.sparkContext.broadcast(feature_columns)
b_factors_returns_mean = spark.sparkContext.broadcast(factors_returns_mean)
b_factors_returns_cov = spark.sparkContext.broadcast(factors_returns_cov)

In [16]:
from pyspark.sql.types import IntegerType

parallelism = 1000
num_trials = 1000000
base_seed = 1496
seeds = [b for b in range(base_seed,
base_seed + parallelism)]
seedsDF = spark.createDataFrame(seeds, IntegerType())
seedsDF = seedsDF.repartition(parallelism)
import random
from numpy.random import seed
from pyspark.sql.types import LongType, ArrayType, DoubleType
from pyspark.sql.functions import udf

def calculate_trial_return(x):
    trial_return_list = []
    for i in range(int(num_trials/parallelism)):
        random_int = random.randint(0, num_trials*num_trials)
        seed(x)
        random_factors = multivariate_normal(b_factors_returns_mean.value,
        b_factors_returns_cov.value)
        coefs_per_stock_df = b_coefs_per_stock.value
        returns_per_stock = (coefs_per_stock_df[b_feature_columns.value] *(list(random_factors) + list(random_factors**2)))
        trial_return_list.append(float(returns_per_stock.sum(axis=1).sum(), b_coefs_per_stock.value.size))
    return trial_return_list

udf_return = udf(calculate_trial_return, ArrayType(DoubleType()))

def calculate_trial_return(x):
    trial_return_list = []
    for i in range(int(num_trials / parallelism)):
        random_int = random.randint(0, num_trials * num_trials)
        seed(x)
        random_factors = multivariate_normal(b_factors_returns_mean.value, b_factors_returns_cov.value)
        coefs_per_stock_df = b_coefs_per_stock.value
        returns_per_stock = (coefs_per_stock_df[b_feature_columns.value] * (list(random_factors) + list(random_factors**2)))
        trial_return_list.append(float(returns_per_stock.sum(axis=1).sum()) / b_coefs_per_stock.value.size)
        trial_return_list.append(float(returns_per_stock.sum(axis=1).sum()) / b_coefs_per_stock.value.size)
    return trial_return_list
udf_return = udf(calculate_trial_return, ArrayType(DoubleType()))

In [17]:
from pyspark.sql.functions import col, explode
trials = seedsDF.withColumn("trial_return", udf_return(col("value")))
trials = trials.select('value', explode('trial_return').alias('trial_return'))
trials.cache()

DataFrame[value: int, trial_return: double]

In [18]:
trials.approxQuantile('trial_return', [0.05], 0.0)

Py4JJavaError: An error occurred while calling o167.approxQuantile.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 11 in stage 20.0 failed 1 times, most recent failure: Lost task 11.0 in stage 20.0 (TID 38) (lpcp-24 executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:95)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:225)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$10(ShuffleExchangeExec.scala:369)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:888)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:888)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:101)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2785)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2721)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2720)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2720)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1206)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2984)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2923)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2912)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:971)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2263)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2358)
	at org.apache.spark.rdd.RDD.$anonfun$fold$1(RDD.scala:1172)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1166)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$2(RDD.scala:1259)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1226)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$1(RDD.scala:1212)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1212)
	at org.apache.spark.sql.execution.stat.StatFunctions$.multipleApproxQuantiles(StatFunctions.scala:100)
	at org.apache.spark.sql.DataFrameStatFunctions.approxQuantile(DataFrameStatFunctions.scala:104)
	at org.apache.spark.sql.DataFrameStatFunctions.approxQuantile(DataFrameStatFunctions.scala:115)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.$anonfun$evaluate$6(BatchEvalPythonExec.scala:95)
	at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:225)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$10(ShuffleExchangeExec.scala:369)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:888)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:888)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:101)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more


In [None]:
trials.orderBy(col('trial_return').asc()).limit(int(trials.count()/20)).agg(fun.avg(col("trial_return"))).show()