In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

24/05/11 21:14:17 WARN Utils: Your hostname, Barts-Mac.local resolves to a loopback address: 127.0.0.1; using 192.168.0.10 instead (on interface en0)
24/05/11 21:14:17 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/11 21:14:17 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Creating a PySpark dataframe from a list of rows: (without schema information)
import pandas as pd
from datetime import datetime, date
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(a=1, b=2., c='string1', d=date(2024, 5, 11), e=datetime(2024, 5, 11, 18, 50, 51)),
    Row(a=2, b=3., c='string2', d=date(2024, 5, 12), e=datetime(2024, 5, 12, 12, 22, 45)),
    Row(a=3, b=4., c='string3', d=date(2024, 5, 13), e=datetime(2024, 5, 13, 22, 1, 51)),
    Row(a=4, b=5., c='string4', d=date(2024, 5, 14), e=datetime(2024, 5, 14, 0, 0)),
])
df

DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

In [3]:
# Creating a PySpark dataframe with schema:

df = spark.createDataFrame([
    (1, 2., 'string1', date(2024, 5, 11), datetime(2024, 5, 11, 18, 50, 51)),
    (2, 3., 'string2', date(2024, 5, 12), datetime(2024, 5, 12, 12, 22, 45)),
    (3, 4., 'string3', date(2024, 5, 13), datetime(2024, 5, 13, 22, 1, 51)),
    (4, 5., 'string4', date(2024, 5, 14), datetime(2024, 5, 14, 0, 0)),
],
    schema='a long, b double, c string, d date, e timestamp')
df

DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

In [4]:
# Creating a PySpark dataframe from pandas dataframe:

pandas_df = pd.DataFrame({
    'a': [1, 2, 3, 5],
    'b': [5., 6., 7., 8.],
    'c': ['string1', 'string2', 'string3', 'string4'],
    'd': [date(2024, 5, 15), date(2024, 5, 16), date(2024, 5, 17), date(2024, 5, 18)],
    'e': [datetime(2024, 5, 15, 18, 15, 51), datetime(2024, 5, 13, 22, 1, 51), datetime(2024, 5, 22, 11, 2, 2), datetime(2024, 5, 31, 3, 13, 3)]
})
df = spark.createDataFrame(pandas_df)
df

  if should_localize and is_datetime64tz_dtype(s.dtype) and s.dt.tz is not None:


DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

In [5]:
# Create a PySpark dataframe from an RDD:

rdd = spark.sparkContext.parallelize([
    (1, 2., 'string1', date(2024, 5, 11), datetime(2024, 5, 11, 18, 50, 51)),
    (2, 3., 'string2', date(2024, 5, 12), datetime(2024, 5, 12, 12, 22, 45)),
    (3, 4., 'string3', date(2024, 5, 13), datetime(2024, 5, 13, 22, 1, 51)),
    (4, 5., 'string4', date(2024, 5, 14), datetime(2024, 5, 14, 0, 0)),
])
df = spark.createDataFrame(rdd, schema=['a', 'b', 'c', 'd', 'e'])
df


                                                                                

DataFrame[a: bigint, b: double, c: string, d: date, e: timestamp]

In [6]:
df.show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  1|2.0|string1|2024-05-11|2024-05-11 18:50:51|
|  2|3.0|string2|2024-05-12|2024-05-12 12:22:45|
|  3|4.0|string3|2024-05-13|2024-05-13 22:01:51|
|  4|5.0|string4|2024-05-14|2024-05-14 00:00:00|
+---+---+-------+----------+-------------------+



In [7]:
df.printSchema()

root
 |-- a: long (nullable = true)
 |-- b: double (nullable = true)
 |-- c: string (nullable = true)
 |-- d: date (nullable = true)
 |-- e: timestamp (nullable = true)



In [8]:
# Configuration of PySpark dataframes representation:

spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
spark.conf.set('spark.sql.repl.eagerEval.maxNumRows', 3)
df

a,b,c,d,e
1,2.0,string1,2024-05-11,2024-05-11 18:50:51
2,3.0,string2,2024-05-12,2024-05-12 12:22:45
3,4.0,string3,2024-05-13,2024-05-13 22:01:51


In [9]:
df.show(2, vertical=True)  # one can show longer rows in vertical view

-RECORD 0------------------
 a   | 1                   
 b   | 2.0                 
 c   | string1             
 d   | 2024-05-11          
 e   | 2024-05-11 18:50:51 
-RECORD 1------------------
 a   | 2                   
 b   | 3.0                 
 c   | string2             
 d   | 2024-05-12          
 e   | 2024-05-12 12:22:45 
only showing top 2 rows



In [10]:
df.columns

['a', 'b', 'c', 'd', 'e']

In [11]:
df.select(['a', 'b', 'c']).describe().show()  # statistics

+-------+------------------+------------------+-------+
|summary|                 a|                 b|      c|
+-------+------------------+------------------+-------+
|  count|                 4|                 4|      4|
|   mean|               2.5|               3.5|   NULL|
| stddev|1.2909944487358056|1.2909944487358056|   NULL|
|    min|                 1|               2.0|string1|
|    max|                 4|               5.0|string4|
+-------+------------------+------------------+-------+



In [12]:
df.collect()  # this can throw out-of-memory error

[Row(a=1, b=2.0, c='string1', d=datetime.date(2024, 5, 11), e=datetime.datetime(2024, 5, 11, 18, 50, 51)),
 Row(a=2, b=3.0, c='string2', d=datetime.date(2024, 5, 12), e=datetime.datetime(2024, 5, 12, 12, 22, 45)),
 Row(a=3, b=4.0, c='string3', d=datetime.date(2024, 5, 13), e=datetime.datetime(2024, 5, 13, 22, 1, 51)),
 Row(a=4, b=5.0, c='string4', d=datetime.date(2024, 5, 14), e=datetime.datetime(2024, 5, 14, 0, 0))]

In [13]:
# so .take() or .tail() are recommended instead:

df.take(1)

[Row(a=1, b=2.0, c='string1', d=datetime.date(2024, 5, 11), e=datetime.datetime(2024, 5, 11, 18, 50, 51))]

In [14]:
# converting PySpark dataframe to Pandas also uses .collect() under the hood, so the memory issues apply

df.toPandas()

  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,a,b,c,d,e
0,1,2.0,string1,2024-05-11,2024-05-11 18:50:51
1,2,3.0,string2,2024-05-12,2024-05-12 12:22:45
2,3,4.0,string3,2024-05-13,2024-05-13 22:01:51
3,4,5.0,string4,2024-05-14,2024-05-14 00:00:00


### Selecting data:

In [15]:
df.b  # lazily evaluated, does not return values

Column<'b'>

In [16]:
type(df.b)

pyspark.sql.column.Column

In [17]:
from pyspark.sql import Column
from pyspark.sql.functions import upper

type(df.c) == type(upper(df.c)) == type(df.c.isNull())  # all are of types Column

True

In [18]:
df.select(df.b).show()

+---+
|  b|
+---+
|2.0|
|3.0|
|4.0|
|5.0|
+---+



In [19]:
# assigning new column:

df.withColumn('upper_c', upper(df.c)).show()

+---+---+-------+----------+-------------------+-------+
|  a|  b|      c|         d|                  e|upper_c|
+---+---+-------+----------+-------------------+-------+
|  1|2.0|string1|2024-05-11|2024-05-11 18:50:51|STRING1|
|  2|3.0|string2|2024-05-12|2024-05-12 12:22:45|STRING2|
|  3|4.0|string3|2024-05-13|2024-05-13 22:01:51|STRING3|
|  4|5.0|string4|2024-05-14|2024-05-14 00:00:00|STRING4|
+---+---+-------+----------+-------------------+-------+



In [20]:
df.filter(df.a==2).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  2|3.0|string2|2024-05-12|2024-05-12 12:22:45|
+---+---+-------+----------+-------------------+



In [21]:
from pyspark.sql.functions import pandas_udf

@pandas_udf('long')
def pandas_plus_1(series: pd.Series) -> pd.Series:
    return series + 1

df.select(pandas_plus_1(df.a)).show()



+----------------+
|pandas_plus_1(a)|
+----------------+
|               2|
|               3|
|               4|
|               5|
+----------------+



In [22]:
def pandas_filter_func(iterator):
    for pandas_df in iterator:
        yield pandas_df[pandas_df.a == 2]

df.mapInPandas(pandas_filter_func, schema = df.schema).show()

+---+---+-------+----------+-------------------+
|  a|  b|      c|         d|                  e|
+---+---+-------+----------+-------------------+
|  2|3.0|string2|2024-05-12|2024-05-12 12:22:45|
+---+---+-------+----------+-------------------+





### Grouping data

In [23]:
df = spark.createDataFrame([
    ['red', 'banana', 1, 10], ['blue', 'banana', 2, 20], ['red', 'carrot', 3, 30],
    ['blue', 'grape', 4, 40], ['red', 'carrot', 5, 50], ['black', 'carrot', 6, 60],
    ['red', 'banana', 7, 70], ['red', 'grape', 8, 80]], schema=['color', 'fruit', 'v1', 'v2'])
df.show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|  red|banana|  1| 10|
| blue|banana|  2| 20|
|  red|carrot|  3| 30|
| blue| grape|  4| 40|
|  red|carrot|  5| 50|
|black|carrot|  6| 60|
|  red|banana|  7| 70|
|  red| grape|  8| 80|
+-----+------+---+---+



In [24]:
df.groupBy('color').avg().show()

+-----+-------+-------+
|color|avg(v1)|avg(v2)|
+-----+-------+-------+
|  red|    4.8|   48.0|
| blue|    3.0|   30.0|
|black|    6.0|   60.0|
+-----+-------+-------+



In [25]:
def minus_mean(pandas_df):
    return pandas_df.assign(v1=pandas_df.v1 - pandas_df.v1.mean())

df.groupBy('color').applyInPandas(minus_mean, schema=df.schema).show()

+-----+------+---+---+
|color| fruit| v1| v2|
+-----+------+---+---+
|black|carrot|  0| 60|
| blue|banana| -1| 20|
| blue| grape|  1| 40|
|  red|banana| -3| 10|
|  red|carrot| -1| 30|
|  red|carrot|  0| 50|
|  red|banana|  2| 70|
|  red| grape|  3| 80|
+-----+------+---+---+





In [26]:
# grouping:

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ('time', 'id', 'v1'))

df2 = spark.createDataFrame(
    [(20000101, 1, 'x'), (20000101, 2, 'y')],
    ('time', 'id', 'v2'))

In [27]:
def asof_join(left, right):
    return pd.merge_asof(left, right, on='time', by='id')

df1.groupby('id').cogroup(df2.groupby('id')).applyInPandas(
    asof_join, schema='time int, id int, v1 double, v2 string'
).show()

+--------+---+---+---+
|    time| id| v1| v2|
+--------+---+---+---+
|20000101|  1|1.0|  x|
|20000102|  1|3.0|  x|
|20000101|  2|2.0|  y|
|20000102|  2|4.0|  y|
+--------+---+---+---+





### I/O file formats

In [28]:
import os


if not os.path.exists('foo.csv'):
    df.write.csv('foo.csv', header = True)
    spark.read.csv('foo.csv', header = True).show()

In [29]:
if not os.path.exists('var.parquet'):
    df.write.parquet('var.parquet')
    spark.read.parquet('var.parquet').show()

In [30]:
if not os.path.exists('loo.orc'):
    df.write.orc('loo.orc')
    spark.read.orc('loo.orc').show()

### Working with SQL

DataFrame and Spark SQL share the same execution engine so they can be interchangeably used seamlessly.

In [31]:
df.createOrReplaceTempView('table1')
spark.sql('SELECT COUNT(*) FROM table1').show()

+--------+
|count(1)|
+--------+
|       8|
+--------+



In [32]:
@pandas_udf('integer')
def add_two(series: pd.Series) -> pd.Series:
    return series + 2

spark.udf.register('add_two', add_two)
spark.sql('SELECT add_two(v1) FROM table1').show()

+-----------+
|add_two(v1)|
+-----------+
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
|         10|
+-----------+





In [33]:
from pyspark.sql.functions import expr

df.selectExpr('add_two(v1)').show()

+-----------+
|add_two(v1)|
+-----------+
|          3|
|          4|
|          5|
|          6|
|          7|
|          8|
|          9|
|         10|
+-----------+





In [34]:
df.select(expr('COUNT(*)') > 1).show()

+--------------+
|(count(1) > 1)|
+--------------+
|          true|
+--------------+



In [37]:
import numpy as np
import pyspark.pandas as ps
from pyspark.sql import SparkSession

s = pd.Series([3, 2, 1, 6, np.nan, 3, 67])
s



0     3.0
1     2.0
2     1.0
3     6.0
4     NaN
5     3.0
6    67.0
dtype: float64

Creating pandas-on-Spark DataFrame

In [41]:
psdf = ps.DataFrame({
    'a': [1, 2, 3, 4, 5, 5, 6, 7, 8],
    'b': [10, 20, 30, 40, 50, 50, 60, 70, 80],
    'c': ['one', 'two', 'three', 'four', 'five', 'five', 'six', 'seven', 'eight']},
    index=[100, 200, 300, 400, 500, 550, 600, 700, 800])
psdf

Unnamed: 0,a,b,c
100,1,10,one
200,2,20,two
300,3,30,three
400,4,40,four
500,5,50,five
550,5,50,five
600,6,60,six
700,7,70,seven
800,8,80,eight


In [42]:
dates = pd.date_range('20240101', periods=9)
dates

DatetimeIndex(['2024-01-01', '2024-01-02', '2024-01-03', '2024-01-04',
               '2024-01-05', '2024-01-06', '2024-01-07', '2024-01-08',
               '2024-01-09'],
              dtype='datetime64[ns]', freq='D')

In [43]:
pdf = pd.DataFrame(np.random.randn(9, 4), index=dates, columns=list('ABCD'))
pdf

Unnamed: 0,A,B,C,D
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038
2024-01-02,-0.679112,1.890224,-0.371296,0.279985
2024-01-03,-0.34373,0.728058,0.739393,1.418214
2024-01-04,-2.048379,1.481702,1.02011,0.936646
2024-01-05,1.445954,0.474426,0.315122,0.98672
2024-01-06,-0.965397,-0.524129,1.303692,0.311385
2024-01-07,1.248959,-0.142986,-0.380655,-1.089089
2024-01-08,-1.264443,-0.413234,0.420525,-0.429557
2024-01-09,-0.272346,-0.128528,-0.811002,0.409272


In [45]:
psdf = ps.from_pandas(pdf)
psdf

  if is_datetime64tz_dtype(s.dtype):
  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,A,B,C,D
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038
2024-01-02,-0.679112,1.890224,-0.371296,0.279985
2024-01-03,-0.34373,0.728058,0.739393,1.418214
2024-01-04,-2.048379,1.481702,1.02011,0.936646
2024-01-05,1.445954,0.474426,0.315122,0.98672
2024-01-06,-0.965397,-0.524129,1.303692,0.311385
2024-01-07,1.248959,-0.142986,-0.380655,-1.089089
2024-01-08,-1.264443,-0.413234,0.420525,-0.429557
2024-01-09,-0.272346,-0.128528,-0.811002,0.409272


Another possible way:

In [47]:
spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(pdf)
sdf.show()

  if should_localize and is_datetime64tz_dtype(s.dtype) and s.dt.tz is not None:


+--------------------+--------------------+-------------------+-------------------+
|                   A|                   B|                  C|                  D|
+--------------------+--------------------+-------------------+-------------------+
| -1.0060815475609166|-0.09941396399635437|-1.3063316525823137|-1.0379998427230053|
| -0.6791118718482954|  1.8902243870095792|-0.3712964643797403| 0.2799845957882774|
|-0.34373000560742517|  0.7280582897516432| 0.7393932777568358| 1.4182144211396155|
|   -2.04837851974698|  1.4817019768628275| 1.0201102597005305| 0.9366459731839303|
|  1.4459544578024373| 0.47442622733643597|0.31512217456026936| 0.9867197121485896|
| -0.9653968631866249| -0.5241291738241138|  1.303692373694545| 0.3113846627989792|
|  1.2489588802256697|-0.14298616966924363| -0.380655050917499| -1.089089300241409|
|  -1.264443338722268| -0.4132338408091108|0.42052519987257475|-0.4295565481727579|
|-0.27234624033473254|-0.12852814089932613|-0.8110017764116179| 0.4092716235

In [48]:
psdf = sdf.pandas_api()
psdf

Unnamed: 0,A,B,C,D
0,-1.006082,-0.099414,-1.306332,-1.038
1,-0.679112,1.890224,-0.371296,0.279985
2,-0.34373,0.728058,0.739393,1.418214
3,-2.048379,1.481702,1.02011,0.936646
4,1.445954,0.474426,0.315122,0.98672
5,-0.965397,-0.524129,1.303692,0.311385
6,1.248959,-0.142986,-0.380655,-1.089089
7,-1.264443,-0.413234,0.420525,-0.429557
8,-0.272346,-0.128528,-0.811002,0.409272


### Missing Data

In [50]:
pdf1 = pdf.reindex(index=dates[:4], columns=list(pdf.columns) + ['E'])
pdf1.loc[dates[0]:dates[1], 'E'] = 7
pdf1

Unnamed: 0,A,B,C,D,E
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038,7.0
2024-01-02,-0.679112,1.890224,-0.371296,0.279985,7.0
2024-01-03,-0.34373,0.728058,0.739393,1.418214,
2024-01-04,-2.048379,1.481702,1.02011,0.936646,


In [51]:
psdf1 = ps.from_pandas(pdf1)
psdf1

  if is_datetime64tz_dtype(s.dtype):
  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,A,B,C,D,E
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038,7.0
2024-01-02,-0.679112,1.890224,-0.371296,0.279985,7.0
2024-01-03,-0.34373,0.728058,0.739393,1.418214,
2024-01-04,-2.048379,1.481702,1.02011,0.936646,


In [52]:
psdf1.dropna(how='any')

  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,A,B,C,D,E
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038,7.0
2024-01-02,-0.679112,1.890224,-0.371296,0.279985,7.0


In [53]:
psdf1.fillna(value=0)

  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,A,B,C,D,E
2024-01-01,-1.006082,-0.099414,-1.306332,-1.038,7.0
2024-01-02,-0.679112,1.890224,-0.371296,0.279985,7.0
2024-01-03,-0.34373,0.728058,0.739393,1.418214,0.0
2024-01-04,-2.048379,1.481702,1.02011,0.936646,0.0


In [54]:
psdf1.mean()



CodeCache: size=131072Kb used=42591Kb max_used=42613Kb free=88480Kb
 bounds [0x000000010c1d8000, 0x000000010ebb8000, 0x00000001141d8000]
 total_blobs=14984 nmethods=14020 adapters=874
 compilation: disabled (not enough contiguous free space left)


A   -1.019325
B    1.000143
C    0.020469
D    0.399211
E    7.000000
dtype: float64

Spark Configurations

In [56]:
import warnings

prev = spark.conf.get('spark.sql.execution.arrow.pyspark.enabled')
ps.set_option('compute.default_index_type', 'distributed')

warnings.filterwarnings('ignore')

In [59]:
spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
%timeit ps.range(300000).to_pandas()

60.3 ms ± 5.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [60]:
spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
%timeit ps.range(300000).to_pandas()

589 ms ± 65.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [61]:
ps.reset_option('compute.default_index_type')
spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', prev)  # returning to previous value

### Grouping PSDF

In [62]:
psdf = ps.DataFrame({'A': ['foo', 'bar', 'foo', 'bar',
                          'foo', 'bar', 'foo', 'foo'],
                    'B': ['one', 'one', 'two', 'three',
                          'two', 'two', 'one', 'three'],
                    'C': np.random.randn(8),
                    'D': np.random.randn(8)})

In [67]:
type(psdf)

pyspark.pandas.frame.DataFrame

In [65]:
psdf.groupby('A').sum()

Unnamed: 0_level_0,C,D
A,Unnamed: 1_level_1,Unnamed: 2_level_1
foo,2.583015,2.679566
bar,2.038361,0.798493


In [68]:
psdf.groupby(['A', 'B']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,C,D
A,B,Unnamed: 2_level_1,Unnamed: 3_level_1
foo,one,0.444703,0.928468
bar,one,0.863376,-0.203181
foo,two,0.234057,1.301747
bar,three,1.010265,-0.59966
bar,two,0.16472,1.601334
foo,three,1.225494,-1.780864
