In [4]:
import findspark
findspark.init()
import pyspark

In [5]:
from pyspark.sql import SparkSession
# Create the notebook's SparkSession, or reuse one that already exists.
spark = (
    SparkSession.builder
    .appName("Python Spark")
    .master('local')
    .getOrCreate()
)

In [5]:
# Keep a handle on the underlying SparkContext for RDD-level work.
sc = spark.sparkContext # for RDD

In [6]:
# Bare expression: the notebook shows the SparkSession object as this cell's output.
spark

In [7]:
from pyspark.sql.types import *
import pandas as pd

## Read file with schema

In [8]:
# Hand-written schema for the airline-delay CSV (nothing is inferred).
# Every column is declared nullable; the list order is the column order.
schema_sdf = StructType([
    StructField(column_name, column_type(), True)
    for column_name, column_type in [
        ('Year', IntegerType),
        ('Month', IntegerType),
        ('DayofMonth', IntegerType),
        ('DayOfWeek', IntegerType),
        ('DepTime', DoubleType),
        ('CRSDepTime', DoubleType),
        ('ArrTime', DoubleType),
        ('CRSArrTime', DoubleType),
        ('UniqueCarrier', StringType),
        ('FlightNum', StringType),
        ('TailNum', StringType),
        ('ActualElapsedTime', DoubleType),
        ('CRSElapsedTime', DoubleType),
        ('AirTime', DoubleType),
        ('ArrDelay', DoubleType),
        ('DepDelay', DoubleType),
        ('Origin', StringType),
        ('Dest', StringType),
        ('Distance', DoubleType),
        ('TaxiIn', DoubleType),
        ('TaxiOut', DoubleType),
        ('Cancelled', IntegerType),
        ('CancellationCode', StringType),
        ('Diverted', IntegerType),
        ('CarrierDelay', DoubleType),
        ('WeatherDelay', DoubleType),
        ('NASDelay', DoubleType),
        ('SecurityDelay', DoubleType),
        ('LateAircraftDelay', DoubleType),
    ]
])

# Read the CSV with its header row, applying the schema declared above.
air = (
    spark.read
    .options(header='true')
    .schema(schema_sdf)
    .csv("/lifeng/student/liutuozhen/airdelay_small.csv")
)

In [9]:
# Print the column names, types and nullability of the Spark DataFrame.
air.printSchema()

root
 |-- Year: integer (nullable = true)
 |-- Month: integer (nullable = true)
 |-- DayofMonth: integer (nullable = true)
 |-- DayOfWeek: integer (nullable = true)
 |-- DepTime: double (nullable = true)
 |-- CRSDepTime: double (nullable = true)
 |-- ArrTime: double (nullable = true)
 |-- CRSArrTime: double (nullable = true)
 |-- UniqueCarrier: string (nullable = true)
 |-- FlightNum: string (nullable = true)
 |-- TailNum: string (nullable = true)
 |-- ActualElapsedTime: double (nullable = true)
 |-- CRSElapsedTime: double (nullable = true)
 |-- AirTime: double (nullable = true)
 |-- ArrDelay: double (nullable = true)
 |-- DepDelay: double (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Distance: double (nullable = true)
 |-- TaxiIn: double (nullable = true)
 |-- TaxiOut: double (nullable = true)
 |-- Cancelled: integer (nullable = true)
 |-- CancellationCode: string (nullable = true)
 |-- Diverted: integer (nullable = true)
 |-- Carrier

In [10]:
# Load airdelay_small.csv into pandas so each Spark operation can be compared with its pandas counterpart.
# NOTE(review): this path differs from the one passed to spark.read — presumably a local copy of the
# same file; confirm both point at identical data.
air_pd = pd.read_csv("/home/student/student/liutuozhen/airdelay/airdelay_small.csv")

In [13]:
# Column types inferred by pandas, for comparison with the hand-written Spark schema.
air_pd.dtypes

Year                   int64
Month                  int64
DayofMonth             int64
DayOfWeek              int64
DepTime              float64
CRSDepTime             int64
ArrTime              float64
CRSArrTime             int64
UniqueCarrier         object
FlightNum              int64
TailNum               object
ActualElapsedTime    float64
CRSElapsedTime       float64
AirTime              float64
ArrDelay             float64
DepDelay             float64
Origin                object
Dest                  object
Distance             float64
TaxiIn               float64
TaxiOut              float64
Cancelled              int64
CancellationCode      object
Diverted               int64
CarrierDelay         float64
WeatherDelay         float64
NASDelay             float64
SecurityDelay        float64
LateAircraftDelay    float64
dtype: object

## Count

In [21]:
# Total number of rows in the Spark DataFrame.
air.count()

21/12/01 20:11:36 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.
                                                                                

5548754

In [16]:
# pandas count() is computed per column and skips missing values, so columns
# containing NAs report fewer entries than the total row count.
air_pd.count()

Year                 5548754
Month                5548754
DayofMonth           5548754
DayOfWeek            5548754
DepTime              5445648
CRSDepTime           5548754
ArrTime              5432958
CRSArrTime           5548754
UniqueCarrier        5548754
FlightNum            5548754
TailNum              3771786
ActualElapsedTime    5432958
CRSElapsedTime       5547553
AirTime              3685725
ArrDelay             5432958
DepDelay             5445648
Origin               5548754
Dest                 5548754
Distance             5539053
TaxiIn               3774512
TaxiOut              3774512
Cancelled            5548754
CancellationCode       28528
Diverted             5548754
CarrierDelay         1556218
WeatherDelay         1556218
NASDelay             1556218
SecurityDelay        1556218
LateAircraftDelay    1556218
dtype: int64

## Descriptive Statistics

In [11]:
# Summary statistics of the Spark DataFrame, printed as a text table.
air.describe().show()



+-------+------------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+-------------+------------------+-------+------------------+------------------+------------------+-----------------+-----------------+-------+-------+-----------------+------------------+------------------+---------+----------------+--------+------------------+------------------+-----------------+--------------------+-----------------+
|summary|              Year|             Month|        DayofMonth|         DayOfWeek|           DepTime|       CRSDepTime|           ArrTime|        CRSArrTime|UniqueCarrier|         FlightNum|TailNum| ActualElapsedTime|    CRSElapsedTime|           AirTime|         ArrDelay|         DepDelay| Origin|   Dest|         Distance|            TaxiIn|           TaxiOut|Cancelled|CancellationCode|Diverted|      CarrierDelay|      WeatherDelay|         NASDelay|       SecurityDelay|LateAircraftDelay|
+---

                                                                                21/12/01 20:41:58 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.


In [14]:
# pandas counterpart: summary statistics of the numeric columns.
air_pd.describe()

Unnamed: 0,Year,Month,DayofMonth,DayOfWeek,DepTime,CRSDepTime,ArrTime,CRSArrTime,FlightNum,ActualElapsedTime,...,Distance,TaxiIn,TaxiOut,Cancelled,Diverted,CarrierDelay,WeatherDelay,NASDelay,SecurityDelay,LateAircraftDelay
count,5548754.0,5548754.0,5548754.0,5548754.0,5445648.0,5548754.0,5432958.0,5548754.0,5548754.0,5432958.0,...,5539053.0,3774512.0,3774512.0,5548754.0,5548754.0,1556218.0,1556218.0,1556218.0,1556218.0,1556218.0
mean,1998.06,6.567632,15.72343,3.941769,1350.179,1335.733,1494.26,1491.223,1315.705,119.6707,...,700.157,6.464043,15.0937,0.01858183,0.002287,3.126645,0.6756155,3.483191,0.02487633,4.016734
std,5.959655,3.445993,8.785812,1.990137,476.997,476.8681,498.5111,493.9483,1348.43,68.46377,...,550.4983,24.09104,11.1247,0.1350428,0.04776788,18.23183,8.620073,15.13413,1.128987,18.52165
min,1987.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0,-681.0,...,11.0,0.0,0.0,0.0,0.0,0.0,0.0,-13.0,0.0,0.0
25%,1993.0,4.0,8.0,2.0,935.0,930.0,1117.0,1115.0,445.0,70.0,...,306.0,4.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,1998.0,7.0,16.0,4.0,1335.0,1330.0,1522.0,1520.0,923.0,101.0,...,543.0,5.0,13.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,2003.0,10.0,23.0,6.0,1739.0,1730.0,1918.0,1913.0,1671.0,151.0,...,936.0,7.0,18.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,2007.0,12.0,31.0,7.0,2644.0,2400.0,2742.0,2400.0,9899.0,1766.0,...,4983.0,1470.0,1439.0,1.0,1.0,1665.0,910.0,1010.0,382.0,1060.0


In [24]:
air_pd.describe()
# Note: the "count" row reported by pandas only counts non-NA values, which is
# why it differs from column to column.

Unnamed: 0,Year,Month,DayofMonth,DayOfWeek,DepTime,CRSDepTime,ArrTime,CRSArrTime,FlightNum,ActualElapsedTime,...,Distance,TaxiIn,TaxiOut,Cancelled,Diverted,CarrierDelay,WeatherDelay,NASDelay,SecurityDelay,LateAircraftDelay
count,5548754.0,5548754.0,5548754.0,5548754.0,5445648.0,5548754.0,5432958.0,5548754.0,5548754.0,5432958.0,...,5539053.0,3774512.0,3774512.0,5548754.0,5548754.0,1556218.0,1556218.0,1556218.0,1556218.0,1556218.0
mean,1998.06,6.567632,15.72343,3.941769,1350.179,1335.733,1494.26,1491.223,1315.705,119.6707,...,700.157,6.464043,15.0937,0.01858183,0.002287,3.126645,0.6756155,3.483191,0.02487633,4.016734
std,5.959655,3.445993,8.785812,1.990137,476.997,476.8681,498.5111,493.9483,1348.43,68.46377,...,550.4983,24.09104,11.1247,0.1350428,0.04776788,18.23183,8.620073,15.13413,1.128987,18.52165
min,1987.0,1.0,1.0,1.0,1.0,0.0,1.0,0.0,1.0,-681.0,...,11.0,0.0,0.0,0.0,0.0,0.0,0.0,-13.0,0.0,0.0
25%,1993.0,4.0,8.0,2.0,935.0,930.0,1117.0,1115.0,445.0,70.0,...,306.0,4.0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,1998.0,7.0,16.0,4.0,1335.0,1330.0,1522.0,1520.0,923.0,101.0,...,543.0,5.0,13.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,2003.0,10.0,23.0,6.0,1739.0,1730.0,1918.0,1913.0,1671.0,151.0,...,936.0,7.0,18.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,2007.0,12.0,31.0,7.0,2644.0,2400.0,2742.0,2400.0,9899.0,1766.0,...,4983.0,1470.0,1439.0,1.0,1.0,1665.0,910.0,1010.0,382.0,1060.0


## select

In [26]:
# Select the carrier code together with a boolean expression (ArrDelay > 0).
air.select(air['UniqueCarrier'], air['ArrDelay']>0).show()

+-------------+--------------+
|UniqueCarrier|(ArrDelay > 0)|
+-------------+--------------+
|           XE|          true|
|           CO|          true|
|           AA|          true|
|           WN|         false|
|           CO|          true|
|           AA|          true|
|           DL|         false|
|           AA|          true|
|           US|          true|
|           AA|          true|
|           AS|         false|
|           UA|          true|
|           TW|         false|
|           NW|          true|
|           NW|          true|
|           AA|          true|
|           DH|          true|
|           WN|         false|
|           AA|         false|
|           CO|          true|
+-------------+--------------+
only showing top 20 rows



In [34]:
# pandas counterpart of the Spark select: carrier code next to an "ArrDelay > 0" flag.
# The vectorised comparison yields the same boolean column as mapping a lambda
# over every element (missing delays compare as False), without one Python
# call per row.
pd.concat([air_pd["UniqueCarrier"], air_pd["ArrDelay"] > 0], axis=1)

Unnamed: 0,UniqueCarrier,ArrDelay
0,XE,True
1,CO,True
2,AA,True
3,WN,False
4,CO,True
...,...,...
5548749,WN,True
5548750,WN,True
5548751,AA,True
5548752,NW,False


## Group and sort

In [11]:
# Count the rows for each carrier, then list carriers from largest to smallest count.
aircount=air.groupBy("UniqueCarrier").count()
aircount.sort("count",ascending=False).show()

21/12/01 19:44:04 WARN Utils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.

+-------------+------+
|UniqueCarrier| count|
+-------------+------+
|           DL|765388|
|           WN|703368|
|           AA|684522|
|           US|649056|
|           UA|611957|
|           NW|473820|
|           CO|373858|
|           TW|179081|
|           HP|173509|
|           MQ|164790|
|           AS|129863|
|           OO|120223|
|           XE| 94311|
|           EV| 67148|
|           OH| 60630|
|           FL| 47540|
|           EA| 43723|
|           PI| 41489|
|           DH| 32900|
|           B6| 29111|
+-------------+------+
only showing top 20 rows



                                                                                

In [33]:
# pandas counterpart: rows per carrier, sorted in descending order.
air_pd.groupby("UniqueCarrier")["UniqueCarrier"].count().sort_values(ascending=False)

UniqueCarrier
DL        765388
WN        703368
AA        684522
US        649056
UA        611957
NW        473820
CO        373858
TW        179081
HP        173509
MQ        164790
AS        129863
OO        120223
XE         94311
EV         67148
OH         60630
FL         47540
EA         43723
PI         41489
DH         32900
B6         29111
YV         28764
PA (1)     15213
9E         12203
F9         11527
HA         10228
TZ          9955
AQ          7142
PS          4059
ML (1)      3376
Name: UniqueCarrier, dtype: int64

## NA

In [20]:
# Number of rows that remain once null SecurityDelay values are dropped.
air[["SecurityDelay"]].na.drop().count()

                                                                                

1556218

In [21]:
# pandas counterpart: count SecurityDelay after dropping missing values.
air_pd[["SecurityDelay"]].dropna().count()

SecurityDelay    1556218
dtype: int64

In [29]:
# Replace null SecurityDelay values with -1 and display the result.
air[["SecurityDelay"]].na.fill(-1).show()

+-------------+
|SecurityDelay|
+-------------+
|          0.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|         -1.0|
|          0.0|
|         -1.0|
|         -1.0|
|         -1.0|
|          0.0|
|          0.0|
|          0.0|
|         -1.0|
|         -1.0|
|         -1.0|
+-------------+
only showing top 20 rows



In [31]:
# pandas counterpart: replace missing SecurityDelay values with -1.
air_pd[["SecurityDelay"]].fillna(-1)

Unnamed: 0,SecurityDelay
0,0.0
1,-1.0
2,-1.0
3,-1.0
4,-1.0
...,...
5548749,-1.0
5548750,-1.0
5548751,0.0
5548752,-1.0


## agg

In [43]:
# Largest arrival delay per carrier, sorted in descending order.
air.groupBy("UniqueCarrier").max('ArrDelay').sort("max(ArrDelay)", ascending=False).show() # max() must be applied to a numeric column



+-------------+-------------+
|UniqueCarrier|max(ArrDelay)|
+-------------+-------------+
|           NW|       1779.0|
|           OO|       1435.0|
|           AA|       1425.0|
|           9E|       1425.0|
|           HP|       1323.0|
|           UA|       1201.0|
|           MQ|       1082.0|
|           EV|       1029.0|
|           OH|        980.0|
|           B6|        960.0|
|           FL|        945.0|
|           CO|        939.0|
|           XE|        852.0|
|           HA|        726.0|
|           YV|        693.0|
|           AS|        683.0|
|       PA (1)|        680.0|
|           DL|        675.0|
|           US|        623.0|
|           WN|        616.0|
+-------------+-------------+
only showing top 20 rows



                                                                                

In [44]:
# pandas counterpart: largest arrival delay per carrier, sorted in descending order.
air_pd.groupby("UniqueCarrier")['ArrDelay'].max().sort_values(ascending=False)

UniqueCarrier
NW        1779.0
OO        1435.0
9E        1425.0
AA        1425.0
HP        1323.0
UA        1201.0
MQ        1082.0
EV        1029.0
OH         980.0
B6         960.0
FL         945.0
CO         939.0
XE         852.0
HA         726.0
YV         693.0
AS         683.0
PA (1)     680.0
DL         675.0
US         623.0
WN         616.0
TZ         614.0
TW         603.0
AQ         600.0
DH         553.0
EA         542.0
F9         393.0
PI         350.0
ML (1)     335.0
PS         283.0
Name: ArrDelay, dtype: float64

## filter

In [24]:
# Keep only the rows whose arrival delay exceeds 60.
air.filter(air.ArrDelay>60).show()

+----+-----+----------+---------+-------+----------+-------+----------+-------------+---------+-------+-----------------+--------------+-------+--------+--------+------+----+--------+------+-------+---------+----------------+--------+------------+------------+--------+-------------+-----------------+
|Year|Month|DayofMonth|DayOfWeek|DepTime|CRSDepTime|ArrTime|CRSArrTime|UniqueCarrier|FlightNum|TailNum|ActualElapsedTime|CRSElapsedTime|AirTime|ArrDelay|DepDelay|Origin|Dest|Distance|TaxiIn|TaxiOut|Cancelled|CancellationCode|Diverted|CarrierDelay|WeatherDelay|NASDelay|SecurityDelay|LateAircraftDelay|
+----+-----+----------+---------+-------+----------+-------+----------+-------------+---------+-------+-----------------+--------------+-------+--------+--------+------+----+--------+------+-------+---------+----------------+--------+------------+------------+--------+-------------+-----------------+
|2007|    5|        17|        4| 1944.0|    1830.0| 2034.0|    1930.0|           AA|     1011

In [46]:
# pandas counterpart: boolean-mask rows whose arrival delay exceeds 60.
air_pd[air_pd.ArrDelay>60]

Unnamed: 0,Year,Month,DayofMonth,DayOfWeek,DepTime,CRSDepTime,ArrTime,CRSArrTime,UniqueCarrier,FlightNum,...,TaxiIn,TaxiOut,Cancelled,CancellationCode,Diverted,CarrierDelay,WeatherDelay,NASDelay,SecurityDelay,LateAircraftDelay
15,2007,5,17,4,1944.0,1830,2034.0,1930,AA,1011,...,4.0,11.0,0,,0,64.0,0.0,0.0,0.0,0.0
46,2001,5,31,4,1812.0,1650,1954.0,1825,WN,1612,...,3.0,13.0,0,,0,,,,,
78,1999,8,13,5,1800.0,1635,2000.0,1826,DL,965,...,10.0,13.0,0,,0,,,,,
96,2004,4,16,5,1550.0,1525,1750.0,1638,XE,2602,...,5.0,80.0,0,,0,0.0,0.0,51.0,0.0,21.0
107,1999,8,17,2,1246.0,1135,1430.0,1310,CO,1843,...,6.0,27.0,0,,0,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5548605,2005,10,10,1,1518.0,1305,1552.0,1343,OO,6216,...,1.0,10.0,0,,0,129.0,0.0,0.0,0.0,0.0
5548611,1995,3,4,6,1706.0,0,1831.0,0,AA,1087,...,3.0,12.0,0,,0,,,,,
5548653,2000,8,20,7,1459.0,1144,1706.0,1356,UA,643,...,7.0,22.0,0,,0,,,,,
5548677,1997,8,24,7,1030.0,922,1504.0,1403,HP,2273,...,4.0,14.0,0,,0,,,,,


## User-defined functions

In [25]:
# Convert four columns of the Spark DataFrame into a pandas DataFrame.
air2_pdf = air.select(["DayOfWeek", "ArrDelay","AirTime","Distance"]).toPandas()

                                                                                

In [26]:
# Preview the first rows of the converted pandas DataFrame.
air2_pdf.head()

Unnamed: 0,DayOfWeek,ArrDelay,AirTime,Distance
0,4.0,2.0,25.0,127.0
1,7.0,29.0,248.0,1623.0
2,,,,
3,5.0,-2.0,70.0,451.0
4,7.0,11.0,133.0,1009.0


In [27]:
import pandas as pd

def myfun(pdf):
    """Return a one-row DataFrame holding the column means of a pandas frame.

    Parameters
    ----------
    pdf : pandas.DataFrame
        Must contain the columns ``ArrDelay``, ``AirTime`` and ``Distance``.

    Returns
    -------
    pandas.DataFrame
        A single row (index 0) with one column per averaged input column,
        in the order ArrDelay, AirTime, Distance.
    """
    column_means = {
        column: pdf[column].mean()
        for column in ("ArrDelay", "AirTime", "Distance")
    }
    return pd.DataFrame(column_means, index=[0])

# Apply the plain pandas function to the converted frame as a baseline.
myfun(air2_pdf)

Unnamed: 0,ArrDelay,AirTime,Distance
0,7.350591,102.688519,729.997977


In [None]:
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("DayOfWeek int, ArrDelay double", PandasUDFType.GROUPED_MAP)
def myfun(pdf):
    """Grouped-map pandas UDF: mean arrival delay of one DayOfWeek group.

    Spark calls this once per group with that group's rows as a pandas
    DataFrame. The returned frame must match the schema declared in the
    decorator, so it carries the grouping key as well as the aggregate, and
    the mean is declared as ``double`` because it is a float.

    Parameters
    ----------
    pdf : pandas.DataFrame
        Rows of a single group; must contain ``DayOfWeek`` and ``ArrDelay``.

    Returns
    -------
    pandas.DataFrame
        One row with the group's ``DayOfWeek`` and its mean ``ArrDelay``.
        To also return AirTime/Distance means, add them both here and to the
        schema string (as ``double``).
    """
    out = dict()
    # Every row in the group shares the same key, so the first value identifies it.
    out["DayOfWeek"] = pdf.DayOfWeek.iloc[0]
    out["ArrDelay"] = pdf.ArrDelay.mean()
    return pd.DataFrame(out, index=[0])

In [29]:
# Keep the four columns used in the user-defined-function examples (the optional Distance filter is disabled).
air2 = air.select(["DayOfWeek","ArrDelay","AirTime","Distance"])# .filter(air.Distance<100)

In [30]:
# Preview the selected columns; rows with nulls are still present at this point.
air2.show()

+---------+--------+-------+--------+
|DayOfWeek|ArrDelay|AirTime|Distance|
+---------+--------+-------+--------+
|        4|     2.0|   25.0|   127.0|
|        7|    29.0|  248.0|  1623.0|
|     null|    null|   null|    null|
|        5|    -2.0|   70.0|   451.0|
|        7|    11.0|  133.0|  1009.0|
|        7|    13.0|  177.0|  1562.0|
|        1|   -12.0|  181.0|  1589.0|
|        3|    11.0|  364.0|  2611.0|
|        5|    13.0|   53.0|   304.0|
|     null|    null|   null|    null|
|        5|    -8.0|  293.0|  2537.0|
|     null|    null|   null|    null|
|     null|    null|   null|    null|
|        2|    55.0|  285.0|  1927.0|
|        1|    23.0|  149.0|   991.0|
|        4|    64.0|   35.0|   193.0|
|        4|    29.0|   25.0|    77.0|
|     null|    null|   null|    null|
|        7|    -6.0|   91.0|   678.0|
|        7|    35.0|  127.0|   998.0|
+---------+--------+-------+--------+
only showing top 20 rows



In [None]:
# Define air3 here (air2 with null-containing rows dropped) so that this cell
# also works when the notebook is run top to bottom on a fresh kernel.
air3 = air2.na.drop()
air3.show()

In [31]:
# Drop rows containing nulls, then run the grouped-map UDF once per DayOfWeek group.
# NOTE(review): the captured output fails inside ArrowStreamReader.readSchema with an
# IllegalArgumentException. That looks like the known PyArrow >= 0.15 / Spark 2.3-2.4
# IPC-format incompatibility (ARROW_PRE_0_15_IPC_FORMAT=1 is the usual workaround) —
# confirm the PyArrow and Spark versions on the cluster.
air3 = air2.na.drop()
air3.select(["DayOfWeek", "ArrDelay"]).groupby("DayOfWeek").apply(myfun).show()

21/12/01 19:55:07 WARN TaskSetManager: Lost task 24.0 in stage 28.0 (TID 478, emr-worker-2.cluster-46968, executor 2): java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
	at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
	at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
	at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
	at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$Reader

Py4JJavaError: An error occurred while calling o241.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 24 in stage 28.0 failed 4 times, most recent failure: Lost task 24.3 in stage 28.0 (TID 486, emr-worker-1.cluster-46968, executor 1): java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
	at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
	at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
	at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
	at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1891)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1879)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1878)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1878)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2112)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2061)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2050)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:738)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3389)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3370)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:80)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:127)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:75)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2764)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
	at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
	at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
	at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
	at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


In [24]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation

# Four 4-dimensional feature vectors (two sparse, two dense); each one is
# wrapped in a 1-tuple so it becomes a single-column row.
data = [
    (feature_vector,)
    for feature_vector in (
        Vectors.sparse(4, [(0, 1.0), (3, -2.0)]),
        Vectors.dense([4.0, 5.0, 0.0, 3.0]),
        Vectors.dense([6.0, 7.0, 0.0, 8.0]),
        Vectors.sparse(4, [(0, 9.0), (3, 1.0)]),
    )
]
df = spark.createDataFrame(data, schema=["features"])


In [27]:
# Display the vector column; sparse vectors print as (size, indices, values).
df.show()

+--------------------+
|            features|
+--------------------+
|(4,[0,3],[1.0,-2.0])|
|   [4.0,5.0,0.0,3.0]|
|   [6.0,7.0,0.0,8.0]|
| (4,[0,3],[9.0,1.0])|
+--------------------+



In [28]:
# Correlation matrix of the "features" vector column; head() takes the first row of the result.
# The third component is 0.0 in every input vector, which is why its row and
# column in the printed matrix are nan.
r1 = Correlation.corr(df, "features").head()
print("Pearson correlation matrix:\n" + str(r1[0]))

21/11/14 13:43:23 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
21/11/14 13:43:23 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS


Pearson correlation matrix:
DenseMatrix([[1.        , 0.05564149,        nan, 0.40047142],
             [0.05564149, 1.        ,        nan, 0.91359586],
             [       nan,        nan, 1.        ,        nan],
             [0.40047142, 0.91359586,        nan, 1.        ]])


21/11/14 13:43:23 WARN PearsonCorrelation: Pearson correlation matrix contains NaN values.
