In [1]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, IntegerType, DateType, StructField, StringType, TimestampType

In [2]:
# Start (or reuse) a Spark session running locally on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('test')
    .getOrCreate()
)

22/02/24 09:55:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
!wget https://nyc-tlc.s3.amazonaws.com/trip+data/fhvhv_tripdata_2021-01.csv

--2022-02-22 22:07:42--  https://nyc-tlc.s3.amazonaws.com/trip+data/fhvhv_tripdata_2021-01.csv
Resolving nyc-tlc.s3.amazonaws.com (nyc-tlc.s3.amazonaws.com)... 52.217.141.233
Connecting to nyc-tlc.s3.amazonaws.com (nyc-tlc.s3.amazonaws.com)|52.217.141.233|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 752335705 (717M) [text/csv]
Saving to: ‘fhvhv_tripdata_2021-01.csv.1’


2022-02-22 22:08:04 (33.8 MB/s) - ‘fhvhv_tripdata_2021-01.csv.1’ saved [752335705/752335705]



Note: `inferSchema=True, header=True` will infer types, but dates will still be strings.
Passing an explicit schema when reading the CSV works.

In [30]:
# Explicit schema: every column is nullable, the two datetime columns are read
# as timestamps and the location IDs as integers.
schema = StructType([
    StructField(column, dtype, True)
    for column, dtype in [
        ("hvfhs_license_num", StringType()),
        ("dispatching_base_num", StringType()),
        ("pickup_datetime", TimestampType()),
        ("dropoff_datetime", TimestampType()),
        ("PULocationID", IntegerType()),
        ("DOLocationID", IntegerType()),
        ("SR_Flag", StringType()),
    ]
])

# Read the CSV with a header row, applying the schema above instead of inferring one.
df = (
    spark.read
    .option("header", True)
    .schema(schema)
    .csv('fhvhv_tripdata_2021-01.csv')
)

In [31]:
# Display the schema attached to the DataFrame read from the CSV.
df.schema

StructType(List(StructField(hvfhs_license_num,StringType,true),StructField(dispatching_base_num,StringType,true),StructField(pickup_datetime,TimestampType,true),StructField(dropoff_datetime,TimestampType,true),StructField(PULocationID,IntegerType,true),StructField(DOLocationID,IntegerType,true),StructField(SR_Flag,StringType,true)))

We can also create a Spark DataFrame from a pandas DataFrame. However, some type conversion happens along the way, so let's check the final schema after conversion. It is better to set the schema explicitly, as done previously, using integer (4 bytes) and avoiding the automatic float64 or long (8 bytes in both cases).

In [6]:
import pandas as pd

In [7]:
# Load the same CSV with pandas, letting pandas infer the column dtypes.
df_pandas = pd.read_csv('fhvhv_tripdata_2021-01.csv')

In [8]:
# Show the dtype pandas assigned to each column.
df_pandas.dtypes

hvfhs_license_num        object
dispatching_base_num     object
pickup_datetime          object
dropoff_datetime         object
PULocationID              int64
DOLocationID              int64
SR_Flag                 float64
dtype: object

In [32]:
# Convert the pandas frame to a Spark DataFrame without an explicit schema and
# show the resulting schema. Needs df_pandas to exist in the current kernel
# (run the pd.read_csv cell first).
spark.createDataFrame(df_pandas).schema

NameError: name 'df_pandas' is not defined

Note that this conversion takes too long. It is better to read the CSV with Spark and set the schema.

In [33]:
# Ask Spark to redistribute the data into 24 partitions.
df = df.repartition(24)

In [34]:
# Write the DataFrame as Parquet, replacing any existing output at this path.
df.write.mode('overwrite').parquet('fhvhv/2021/01')

                                                                                

Writing to Parquet format uses several executors in the cluster to process the partitions in parallel. `repartition` is a lazy transformation: nothing is processed until an action, such as `write.parquet`, is performed.

In [35]:
# Load the Parquet dataset stored under fhvhv/2021/01/ into a DataFrame.
df = spark.read.parquet('fhvhv/2021/01/')

In [36]:
# Print the schema of the DataFrame as a tree.
df.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- SR_Flag: string (nullable = true)



`select`, `filter`, joins, `groupBy`... are transformations: they are lazy and are not executed immediately. `show`, `take`, `head`, `write`... are actions: when one is called, it is executed together with the pending transformations it depends on.

In [39]:
from pyspark.sql import functions as F

In [40]:
def crazy_stuff(base_num):
    """Build a synthetic base id from a dispatching base number.

    The first character of ``base_num`` is dropped and the remainder is parsed
    as a decimal integer. The result is that integer in lowercase hex,
    zero-padded to at least three digits, prefixed with ``s/`` when the integer
    is divisible by 7 and with ``e/`` otherwise.

    Args:
        base_num: Base number string such as ``'B02884'``, or ``None``.

    Returns:
        The formatted id (e.g. ``'s/b44'``), or ``None`` when ``base_num`` is
        ``None``.

    Raises:
        ValueError: If the text after the first character is not an integer.
    """
    # The source column is nullable and a null reaches a Python UDF as None;
    # propagate it instead of failing on base_num[1:].
    if base_num is None:
        return None

    num = int(base_num[1:])
    prefix = 's' if num % 7 == 0 else 'e'
    return f'{prefix}/{num:03x}'

In [42]:
# Wrap crazy_stuff as a Spark UDF whose result column is a string.
crazy_stuff_udf = F.udf(crazy_stuff, returnType=StringType())

Note: `select` returns a DataFrame. `df.xxxx` returns a Column.

In [44]:
# Derive the date and base-id columns directly inside a single projection,
# then display the first rows.
(
    df
    .select(
        F.to_date(df.pickup_datetime).alias('pickup_date'),
        F.to_date(df.dropoff_datetime).alias('dropoff_date'),
        'PULocationID',
        'DOLocationID',
        crazy_stuff_udf(df.dispatching_base_num).alias('base_id'),
    )
    .show()
)

[Stage 18:>                                                         (0 + 1) / 1]

+-----------+------------+------------+------------+-------+
|pickup_date|dropoff_date|PULocationID|DOLocationID|base_id|
+-----------+------------+------------+------------+-------+
| 2021-01-07|  2021-01-07|         142|         230|  e/9ce|
| 2021-01-01|  2021-01-01|         133|          91|  e/9ce|
| 2021-01-01|  2021-01-01|         147|         159|  e/acc|
| 2021-01-06|  2021-01-06|          79|         164|  e/b35|
| 2021-01-04|  2021-01-04|         174|          18|  s/b44|
| 2021-01-04|  2021-01-04|         201|         180|  e/b3b|
| 2021-01-04|  2021-01-04|         230|         142|  e/9ce|
| 2021-01-03|  2021-01-03|         132|          72|  e/b38|
| 2021-01-01|  2021-01-01|         188|          61|  s/af0|
| 2021-01-04|  2021-01-04|          97|         189|  e/9ce|
| 2021-01-01|  2021-01-01|         174|         235|  e/acc|
| 2021-01-05|  2021-01-05|          35|          76|  e/b37|
| 2021-01-06|  2021-01-06|          35|          39|  e/9ce|
| 2021-01-04|  2021-01-0

                                                                                

In [45]:
# Filter first, while hvfhs_license_num is still a column of the frame, and
# only then project the columns of interest. Filtering after the select only
# worked because Spark resolves the dropped column behind the scenes.
(
    df
    .filter(df.hvfhs_license_num == 'HV0003')
    .select('pickup_datetime', 'dropoff_datetime', 'PULocationID', 'DOLocationID', 'SR_Flag')
    .take(5)
)

[Row(pickup_datetime=datetime.datetime(2021, 1, 1, 0, 23, 13), dropoff_datetime=datetime.datetime(2021, 1, 1, 0, 30, 35), PULocationID=147, DOLocationID=159, SR_Flag=None),
 Row(pickup_datetime=datetime.datetime(2021, 1, 6, 11, 43, 12), dropoff_datetime=datetime.datetime(2021, 1, 6, 11, 55, 7), PULocationID=79, DOLocationID=164, SR_Flag=None),
 Row(pickup_datetime=datetime.datetime(2021, 1, 4, 15, 35, 32), dropoff_datetime=datetime.datetime(2021, 1, 4, 15, 52, 2), PULocationID=174, DOLocationID=18, SR_Flag=None),
 Row(pickup_datetime=datetime.datetime(2021, 1, 4, 13, 42, 15), dropoff_datetime=datetime.datetime(2021, 1, 4, 14, 4, 57), PULocationID=201, DOLocationID=180, SR_Flag=None),
 Row(pickup_datetime=datetime.datetime(2021, 1, 3, 18, 42, 3), dropoff_datetime=datetime.datetime(2021, 1, 3, 19, 12, 22), PULocationID=132, DOLocationID=72, SR_Flag=None)]

Note: `take` returns a list of `Row` objects, not a DataFrame.