In [22]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *


spark = SparkSession.builder.\
            master('local[*]').\
            config("spark.driver.memory", "5g").\
            appName('project').\
            config("spark.jars", "postgresql-42.7.0.jar").\
            getOrCreate()

In [23]:
# import pandas as pd

In [24]:
import pyarrow.parquet as pq

In [25]:
# pq.ParquetFile("data/yellowData/yellow_tripdata_2023-05.parquet").schema
# int64 - long
# int32 - int

In [26]:
# pq.ParquetFile("data/greenData/green_tripdata_2023-05.parquet").schema

In [27]:
yellow_df_schema = StructType([
    StructField('VendorId', LongType()),
    StructField('tpep_pickup_datetime', TimestampType()),
    StructField('tpep_dropoff_datetime', TimestampType()),
    StructField('passenger_count', LongType()),
    StructField('trip_distance',DoubleType()),
    StructField('RatecodeID',LongType()),
    StructField('store_and_fwd_flag',StringType()),
    StructField('PULocationID', IntegerType()),
    StructField('DOLocationID', IntegerType()),
    StructField('payment_type', LongType()),
    StructField('fare_amount',DoubleType()),
    StructField('extra',DoubleType()),
    StructField('mta_tax',DoubleType()),
    StructField('tip_amount',DoubleType()),
    StructField('tolls_amount',DoubleType()),
    StructField('improvement_surcharge',DoubleType()),
    StructField('total_amount',DoubleType()),
    StructField('congestion_surcharge',DoubleType()),
    StructField('airport_fee',DoubleType()),
])

green_df_schema = StructType([
    StructField('VendorId', IntegerType()),
    StructField('lpep_pickup_datetime', TimestampType()),
    StructField('lpep_dropoff_datetime', TimestampType()),
    StructField('store_and_fwd_flag',StringType()),
    StructField('RatecodeID',LongType()),
    StructField('PULocationID', IntegerType()),
    StructField('DOLocationID', IntegerType()),
    StructField('passenger_count', LongType()),
    StructField('trip_distance',DoubleType()),
    StructField('fare_amount',DoubleType()),
    StructField('extra',DoubleType()),
    StructField('mta_tax',DoubleType()),
    StructField('tip_amount',DoubleType()),
    StructField('tolls_amount',DoubleType()),
    StructField('ehail_fee', DoubleType()),
    StructField('improvement_surcharge',DoubleType()),
    StructField('total_amount',DoubleType()),
    StructField('payment_type', LongType()),
    StructField('trip_type', LongType()),
    StructField('congestion_surcharge',DoubleType()),
])

In [28]:
# yellow_df = spark.read.schema(yellow_df_schema).parquet("data/yellowData/*.parquet")
# green_df = spark.read.schema(green_df_schema).parquet("data/greenData/*.parquet")

In [29]:
import os

In [147]:
def filenames_func(directory_path):
    
    files = [i for i in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, i))]
    files_filtered = [os.path.join(directory_path,f) for f in files if not f.startswith('.')]
    
    return files_filtered

In [148]:
def create_df(file_paths):
    new_df = spark.read.parquet(file_paths[0])
    for i in file_paths[1:5]:
        new_df = new_df.\
                    union(spark.read.\
                    parquet(i))
    
    return new_df

In [149]:
yellow_df = create_df(filenames_func('./data/yellowData/'))
# green_df = create_df(filenames_func('./data/greenData/'))

In [150]:
print('number of records in yellow taxi data : ', yellow_df.count())
# print('number of records in green taxi data : ', green_df.count())

number of records in yellow taxi data :  15840450


In [151]:
yellow_df.select('VendorID').distinct().show()



+--------+
|VendorID|
+--------+
|       1|
|       6|
|       2|
+--------+



                                                                                

In [152]:
yellow_df.dtypes

[('VendorID', 'int'),
 ('tpep_pickup_datetime', 'timestamp_ntz'),
 ('tpep_dropoff_datetime', 'timestamp_ntz'),
 ('passenger_count', 'bigint'),
 ('trip_distance', 'double'),
 ('RatecodeID', 'bigint'),
 ('store_and_fwd_flag', 'string'),
 ('PULocationID', 'int'),
 ('DOLocationID', 'int'),
 ('payment_type', 'bigint'),
 ('fare_amount', 'double'),
 ('extra', 'double'),
 ('mta_tax', 'double'),
 ('tip_amount', 'double'),
 ('tolls_amount', 'double'),
 ('improvement_surcharge', 'double'),
 ('total_amount', 'double'),
 ('congestion_surcharge', 'double'),
 ('Airport_fee', 'double')]

In [153]:
temp_list = ['VendorID',
             'passenger_count',
             'RateCodeID',
             'store_and_fwd_flag',
             'payment_type',
             'Airport_fee']

In [154]:
for i in temp_list:
    distinct_values_list = yellow_df.select(i).distinct().collect()
    print(i, len(distinct_values_list), [i[0] for i in distinct_values_list])

VendorID 3 [1, 6, 2]


                                                                                

passenger_count 11 [0, 7, 6, 9, 5, 1, 3, 8, 2, 4, None]
RateCodeID 8 [6, 5, 1, 3, 2, 4, 99, None]




store_and_fwd_flag 3 ['Y', 'N', None]
payment_type 6 [0, 5, 1, 3, 2, 4]
Airport_fee 8 [0.0, -1.75, 1.7, None, 1.75, 1.25, 1.0, -1.25]


In [155]:
# lower case all column names
for col in yellow_df.columns:
    yellow_df = yellow_df.withColumnRenamed(col, col.lower())

In [156]:
yellow_df.dtypes

[('vendorid', 'int'),
 ('tpep_pickup_datetime', 'timestamp_ntz'),
 ('tpep_dropoff_datetime', 'timestamp_ntz'),
 ('passenger_count', 'bigint'),
 ('trip_distance', 'double'),
 ('ratecodeid', 'bigint'),
 ('store_and_fwd_flag', 'string'),
 ('pulocationid', 'int'),
 ('dolocationid', 'int'),
 ('payment_type', 'bigint'),
 ('fare_amount', 'double'),
 ('extra', 'double'),
 ('mta_tax', 'double'),
 ('tip_amount', 'double'),
 ('tolls_amount', 'double'),
 ('improvement_surcharge', 'double'),
 ('total_amount', 'double'),
 ('congestion_surcharge', 'double'),
 ('airport_fee', 'double')]

In [163]:
spark.read.parquet('datetimedim.parquet').count()

10181689

## Vendor ID

In [40]:
from pyspark.sql.functions import col
# VendorID
# cast to int
yellow_df = yellow_df.withColumn('vendorid', col('vendorid').cast(IntegerType()))
# can only have value 1 or 2 acc to data dictionary, drop other rows
yellow_df = yellow_df.filter(col('vendorid').isin(1,2))

In [41]:
yellow_df.count()

15837138

In [42]:
yellow_df.select('vendorid').distinct().show()



+--------+
|vendorid|
+--------+
|       1|
|       2|
+--------+



                                                                                

## Passenger Count

In [43]:
yellow_df.select('passenger_count').groupBy('passenger_count').count().show()



+---------------+--------+
|passenger_count|   count|
+---------------+--------+
|              0|  256536|
|              7|      31|
|           NULL|  462033|
|              6|  132533|
|              9|      27|
|              5|  194646|
|              1|11473702|
|              3|  612233|
|              8|     114|
|              2| 2350579|
|              4|  354704|
+---------------+--------+





In [44]:
#Passenger Count
# cast to int
yellow_df = yellow_df.withColumn('passenger_count', col('passenger_count').\
                                 cast(IntegerType()))
#remove None records
yellow_df = yellow_df.dropna(how = 'all', subset = ['passenger_count'])
#remove records with 0 as passenger count
yellow_df = yellow_df.filter(col('passenger_count') > 0)

In [45]:
yellow_df.select('passenger_count').groupBy('passenger_count').count().show()



+---------------+--------+
|passenger_count|   count|
+---------------+--------+
|              1|11473702|
|              6|  132533|
|              3|  612233|
|              5|  194646|
|              9|      27|
|              4|  354704|
|              8|     114|
|              7|      31|
|              2| 2350579|
+---------------+--------+



                                                                                

In [46]:
yellow_df.count()

                                                                                

15118569

## Trip Distance

In [47]:
yellow_df.filter(col('trip_distance') == 0).count()

                                                                                

193152

In [48]:
# cast to float
yellow_df = yellow_df.withColumn('trip_distance', col('trip_distance').cast(FloatType()))
#drop rows with trip distance = 0
yellow_df = yellow_df.filter((col('trip_distance') != 0) & (col('trip_distance') > 0))

In [49]:
yellow_df.count()

                                                                                

14925417

## Location ID

In [50]:
from pyspark.sql.functions import isnull

In [51]:
help(isnull)

Help on function isnull in module pyspark.sql.functions:

isnull(col: 'ColumnOrName') -> pyspark.sql.column.Column
    An expression that returns true if the column is null.
    
    .. versionadded:: 1.6.0
    
    .. versionchanged:: 3.4.0
        Supports Spark Connect.
    
    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or str
        target column to compute on.
    
    Returns
    -------
    :class:`~pyspark.sql.Column`
        True if value is null and False otherwise.
    
    Examples
    --------
    >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b"))
    >>> df.select("a", "b", isnull("a").alias("r1"), isnull(df.b).alias("r2")).show()
    +----+----+-----+-----+
    |   a|   b|   r1|   r2|
    +----+----+-----+-----+
    |   1|NULL|false| true|
    |NULL|   2| true|false|
    +----+----+-----+-----+



In [52]:
yellow_df.\
    select('pulocationid').\
    filter( (col('pulocationid') <1) | (isnull(col('pulocationid'))) ).count()

                                                                                

0

In [53]:
yellow_df.\
    select('dolocationid').\
    filter( (col('dolocationid') <1) | (isnull(col('dolocationid'))) ).count()

                                                                                

0

In [54]:
# cast location id to int
yellow_df = yellow_df.\
                withColumn('pulocationid', col('pulocationid').cast(IntegerType())).\
                withColumn('dolocationid', col('dolocationid').cast(IntegerType()))

In [55]:
yellow_df.dtypes

[('vendorid', 'int'),
 ('tpep_pickup_datetime', 'timestamp_ntz'),
 ('tpep_dropoff_datetime', 'timestamp_ntz'),
 ('passenger_count', 'int'),
 ('trip_distance', 'float'),
 ('ratecodeid', 'bigint'),
 ('store_and_fwd_flag', 'string'),
 ('pulocationid', 'int'),
 ('dolocationid', 'int'),
 ('payment_type', 'bigint'),
 ('fare_amount', 'double'),
 ('extra', 'double'),
 ('mta_tax', 'double'),
 ('tip_amount', 'double'),
 ('tolls_amount', 'double'),
 ('improvement_surcharge', 'double'),
 ('total_amount', 'double'),
 ('congestion_surcharge', 'double'),
 ('airport_fee', 'double')]

## Rate Code ID

In [56]:
## cast ratecode id to int
yellow_df = yellow_df.\
                withColumn('ratecodeid', col('ratecodeid').cast(IntegerType()))


In [57]:
## drop rows with invalid ratecode id
yellow_df = yellow_df.\
                filter(col('ratecodeid').isin(1,2,3,4,5,6))

## Store & fwd flag

In [58]:
from pyspark.sql.functions import lower

In [59]:
yellow_df.select('store_and_fwd_flag').distinct().show()



+------------------+
|store_and_fwd_flag|
+------------------+
|                 Y|
|                 N|
+------------------+



                                                                                

In [60]:
yellow_df.select('store_and_fwd_flag').groupBy('store_and_fwd_flag').count().show()



+------------------+--------+
|store_and_fwd_flag|   count|
+------------------+--------+
|                 Y|   85369|
|                 N|14765088|
+------------------+--------+



                                                                                

In [61]:
# convert values to lowercase
yellow_df = yellow_df.withColumn('store_and_fwd_flag', lower(col('store_and_fwd_flag')))

In [62]:
yellow_df.select('store_and_fwd_flag').groupBy('store_and_fwd_flag').count().show()



+------------------+--------+
|store_and_fwd_flag|   count|
+------------------+--------+
|                 n|14765088|
|                 y|   85369|
+------------------+--------+



                                                                                

## Payment Type

In [63]:
yellow_df.\
    select('payment_type').\
    groupBy('payment_type').\
    count().show()



+------------+--------+
|payment_type|   count|
+------------+--------+
|           1|11919747|
|           3|   72289|
|           2| 2672504|
|           4|  185917|
+------------+--------+



                                                                                

## Fare amount

In [64]:
yellow_df.\
    select('fare_amount').\
    filter(col('fare_amount') < 0).\
    count()

                                                                                

136941

In [65]:
from pyspark.sql.functions import when
## adding a column saying whether the fare amount is +ve or -ve
yellow_df = yellow_df.\
                withColumn('errordata',when(col('fare_amount') < 0, 'y').\
                otherwise('n'))

## casting the column to double

yellow_df = yellow_df.\
                withColumn('fare_amount', col('fare_amount').cast(DoubleType()))

## Extra

In [66]:
from pyspark.sql.functions import max

yellow_df.select('extra').groupBy('extra').count().show(1000)



+-----+-------+
|extra|  count|
+-----+-------+
|  0.0|5611620|
|  3.5| 852177|
|10.25|  18756|
| -5.0|   5112|
| 12.5|      4|
|  4.5|     25|
| 0.05|      2|
| -1.0|  42563|
| 9.25|  66190|
|  0.7|     17|
|  2.5|3692502|
| 2.15|      3|
|  1.0|2988270|
| 2.45|     64|
|  2.7|      1|
|  2.8|      4|
| -6.0|   1934|
| 0.04|      5|
| 7.45|      4|
|  7.5| 134700|
| 0.75|     52|
|  3.8|      1|
| 11.0|      6|
| 2.75|  26202|
|  8.5|   3068|
| 2.25|    804|
|  3.2|    254|
|14.25|     22|
| -2.5|  22110|
|12.75|     22|
| 0.03|      3|
| 6.75|  25228|
|  2.0|      2|
|  1.5|     65|
| 0.25|     12|
|  8.2|      1|
| 4.25|  71144|
| 3.25|     85|
| 0.01|      3|
| 10.0|   4459|
| 5.75|      6|
| 0.11|      2|
| 0.02|      9|
|  6.0| 138174|
| 1.75|  35158|
| 7.75|  16262|
| -7.5|   1191|
|  5.0|1070836|
|11.75|  13798|
|-4.25|      1|
|-0.25|      1|
| -3.5|      4|
| 5.25|    314|
| 0.06|      1|
|  7.0|      7|
| 1.25|   1112|
|  1.7|      1|
|-0.75|      6|
| 4.75|     13|
|  3.1| 

                                                                                

In [67]:
yellow_df.select('mta_tax').groupBy('mta_tax').count().show(1000)



+-------+--------+
|mta_tax|   count|
+-------+--------+
|    0.0|  101294|
|   0.05|     232|
|    4.0|      14|
|    0.5|14613701|
|  -0.05|       6|
|   -0.5|  135204|
|   5.75|       1|
|   1.05|       1|
|    3.5|       1|
|    0.8|       1|
|   1.53|       1|
|   3.25|       1|
+-------+--------+





In [68]:
# yellow_df.filter()

In [69]:
yellow_df.columns

['vendorid',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'ratecodeid',
 'store_and_fwd_flag',
 'pulocationid',
 'dolocationid',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'airport_fee',
 'errordata']

In [70]:
yellow_df.\
    filter(
        (col('extra') >= 0) &
        (col('mta_tax') >= 0) &
        (col('tip_amount') >= 0) &
        (col('tolls_amount') >= 0) &
        (col('improvement_surcharge') >= 0) &
        (col('total_amount') >= 0) &
        (col('congestion_surcharge') >= 0) &
        (col('airport_fee') >= 0) 
    ).count()

                                                                                

14712990

In [71]:
yellow_df.count()

                                                                                

14850457

In [72]:
df = spark.createDataFrame([('X',)], schema = 'name string')

In [73]:
cond = col('name') == 'X'

In [74]:
df.withColumn('test', when(cond, 1).otherwise(0)).show()

                                                                                

+----+----+
|name|test|
+----+----+
|   X|   1|
+----+----+



In [75]:
help(spark.createDataFrame)

Help on method createDataFrame in module pyspark.sql.session:

createDataFrame(data: Union[pyspark.rdd.RDD[Any], Iterable[Any], ForwardRef('PandasDataFrameLike'), ForwardRef('ArrayLike')], schema: Union[pyspark.sql.types.AtomicType, pyspark.sql.types.StructType, str, NoneType] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True) -> pyspark.sql.dataframe.DataFrame method of pyspark.sql.session.SparkSession instance
    Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`
    or a :class:`numpy.ndarray`.
    
    .. versionadded:: 2.0.0
    
    .. versionchanged:: 3.4.0
        Supports Spark Connect.
    
    Parameters
    ----------
    data : :class:`RDD` or iterable
        an RDD of any kind of SQL data representation (:class:`Row`,
        :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`,
        :class:`pandas.DataFrame` or :class:`numpy.ndarray`.
    schema : :class:`pyspark.sql.types.DataType`, str or list, op

In [76]:
non_negative_condition =(
                            (col('fare_amount') >= 0)& 
                            (col('extra') >= 0) &
                            (col('mta_tax') >= 0) &
                            (col('tip_amount') >= 0) &
                            (col('tolls_amount') >= 0) &
                            (col('improvement_surcharge') >= 0) &
                            (col('total_amount') >= 0) &
                            (col('congestion_surcharge') >= 0) &
                            (col('airport_fee') >= 0) 
    )

yellow_df = yellow_df.\
        withColumn('errordata',when(non_negative_condition, 'n').\
        otherwise('y'))

In [77]:
yellow_df.count()

                                                                                

14850457

In [78]:
yellow_df.groupby('errordata').count().show()



+---------+--------+
|errordata|   count|
+---------+--------+
|        n|14712990|
|        y|  137467|
+---------+--------+



                                                                                

In [79]:
# from pyspark.sql.functions import monotonically_increasing_id 

# yellow_df.\
#     select('vendorid').\
#     distinct().\
#     withColumn('vendor_key', monotonically_increasing_id()).\
#     withColumn('vendor_name', )
# show()

In [80]:
from pyspark.sql.functions import monotonically_increasing_id 

help(monotonically_increasing_id)

Help on function monotonically_increasing_id in module pyspark.sql.functions:

monotonically_increasing_id() -> pyspark.sql.column.Column
    A column that generates monotonically increasing 64-bit integers.
    
    The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
    The current implementation puts the partition ID in the upper 31 bits, and the record number
    within each partition in the lower 33 bits. The assumption is that the data frame has
    less than 1 billion partitions, and each partition has less than 8 billion records.
    
    .. versionadded:: 1.6.0
    
    .. versionchanged:: 3.4.0
        Supports Spark Connect.
    
    Notes
    -----
    The function is non-deterministic because its result depends on partition IDs.
    
    As an example, consider a :class:`DataFrame` with two partitions, each with 3 records.
    This expression would return the following IDs:
    0, 1, 2, 8589934592 (1L << 33), 8589934593, 858993459

In [81]:
@udf
def vendor_name_udf(vendorid):
    vendor_dict = {
                    1 : 'Creative Mobile Technologies, LLC',
                   2 : 'VeriFone Inc.'
    }
    if vendorid is not None:
        return vendor_dict[vendorid]

vendor_DIM = yellow_df.\
    select('vendorid').\
    distinct().\
    withColumn('vendor_key', monotonically_increasing_id()).\
    withColumn('vendor_name', vendor_name_udf(col('vendorid')))

In [82]:
vendor_DIM.show()

[Stage 129:>                                                        (0 + 1) / 1]

+--------+----------+--------------------+
|vendorid|vendor_key|         vendor_name|
+--------+----------+--------------------+
|       1|         0|Creative Mobile T...|
|       2|         1|       VeriFone Inc.|
+--------+----------+--------------------+



                                                                                

In [83]:
yellow_df.columns

['vendorid',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'ratecodeid',
 'store_and_fwd_flag',
 'pulocationid',
 'dolocationid',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'airport_fee',
 'errordata']

In [84]:
pu_times = yellow_df.\
    select(col('tpep_pickup_datetime').alias('timestamp')).\
    dropDuplicates()
pu_times.cache()

DataFrame[timestamp: timestamp_ntz]

In [85]:
do_times = yellow_df.\
    select(col('tpep_dropoff_datetime').alias('timestamp')).\
    dropDuplicates()
do_times.cache()

DataFrame[timestamp: timestamp_ntz]

In [86]:
datetime_DIM = pu_times.\
    union(do_times).\
    dropDuplicates().\
    withColumn('datetime_key', monotonically_increasing_id())

In [87]:
do_times.count()

                                                                                

7901697

In [88]:
pu_times.count()

                                                                                

7904168

In [89]:
datetime_DIM.show()

[Stage 152:>                                                        (0 + 1) / 1]

+-------------------+------------+
|          timestamp|datetime_key|
+-------------------+------------+
|2023-06-01 00:48:24|           0|
|2023-06-01 00:38:30|           1|
|2023-06-01 00:02:52|           2|
|2023-06-01 00:35:38|           3|
|2023-06-01 00:33:21|           4|
|2023-06-01 00:23:20|           5|
|2023-06-01 00:25:58|           6|
|2023-06-01 00:09:42|           7|
|2023-06-01 01:07:01|           8|
|2023-06-01 01:30:58|           9|
|2023-06-01 01:26:11|          10|
|2023-06-01 01:09:09|          11|
|2023-06-01 01:19:48|          12|
|2023-06-01 02:16:29|          13|
|2023-06-01 03:22:23|          14|
|2023-06-01 04:18:30|          15|
|2023-06-01 04:12:21|          16|
|2023-06-01 05:16:51|          17|
|2023-06-01 06:39:13|          18|
|2023-06-01 06:00:09|          19|
+-------------------+------------+
only showing top 20 rows



                                                                                

In [90]:
datetime_DIM.count()

                                                                                

10181689

In [91]:
yellow_df.columns

['vendorid',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'ratecodeid',
 'store_and_fwd_flag',
 'pulocationid',
 'dolocationid',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'airport_fee',
 'errordata']

In [92]:
pu_locations = yellow_df.\
    select(col('pulocationid').alias('locationid')).\
    dropDuplicates()

In [93]:
do_locations = yellow_df.\
    select(col('dolocationid').alias('locationid')).\
    dropDuplicates()

In [94]:
taxizone_DIM = pu_locations.\
    union(do_locations).\
    dropDuplicates().\
    withColumn('zone_key', monotonically_increasing_id())

In [95]:
taxizone_DIM.count()

                                                                                

263

In [96]:
@udf
def paymenttype_name_udf(payment_type):
    payment_type_dict = {
                    1 : 'Credit card',
                    2 : 'Cash',
                    3 : 'No charge',
                    4 : 'Dispute',
                    5 : 'Unknown',
                    6 : 'Voided trip'
    }
    if payment_type is not None:
        return payment_type_dict[payment_type]

paymenttype_DIM = yellow_df.\
    select('payment_type').\
    distinct().\
    withColumn('paymenttype_key', monotonically_increasing_id()).\
    withColumn('description', paymenttype_name_udf(col('payment_type')))

In [97]:
paymenttype_DIM.show()



+------------+---------------+-----------+
|payment_type|paymenttype_key|description|
+------------+---------------+-----------+
|           1|              0|Credit card|
|           3|              1|  No charge|
|           2|              2|       Cash|
|           4|              3|    Dispute|
+------------+---------------+-----------+



                                                                                

In [98]:
@udf
def ratecode_name_udf(ratecode):
    ratecode_dict = {
                    1 : 'Standard rate',
                    2 : 'JFK',
                    3 : 'Newark',
                    4 : 'Nassau or Westchester',
                    5 : 'Negotiated fare',
                    6 : 'Group ride'
    }
    if ratecode is not None:
        return ratecode_dict[ratecode]

ratecode_DIM = yellow_df.\
    select('ratecodeid').\
    distinct().\
    withColumn('ratecode_key', monotonically_increasing_id()).\
    withColumn('description', ratecode_name_udf(col('ratecodeid')))

In [99]:
ratecode_DIM.show()



+----------+------------+--------------------+
|ratecodeid|ratecode_key|         description|
+----------+------------+--------------------+
|         1|           0|       Standard rate|
|         6|           1|          Group ride|
|         3|           2|              Newark|
|         5|           3|     Negotiated fare|
|         4|           4|Nassau or Westche...|
|         2|           5|                 JFK|
+----------+------------+--------------------+



                                                                                

In [102]:
print(yellow_df.columns)
print(datetime_DIM.columns)

['vendorid', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_distance', 'ratecodeid', 'store_and_fwd_flag', 'pulocationid', 'dolocationid', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount', 'congestion_surcharge', 'airport_fee', 'errordata']
['timestamp', 'datetime_key']


In [123]:
## swapped pickup datetime with putime_key
fact_table = yellow_df.join(datetime_DIM, yellow_df.tpep_pickup_datetime == datetime_DIM.timestamp, 'inner').\
                drop('tpep_pickup_datetime','timestamp').\
                withColumnRenamed('datetime_key', 'putime_key')

In [125]:
## swapped dropoff datetime with dotime_key
fact_table = fact_table.join(datetime_DIM, fact_table.tpep_dropoff_datetime == datetime_DIM.timestamp, 'inner').\
                drop('tpep_dropoff_datetime','timestamp').\
                withColumnRenamed('datetime_key', 'dotime_key')

In [126]:
## swapped vendor id with vendor_key
fact_table = fact_table.join(vendor_DIM, 'vendorid', 'inner').\
                drop('vendorid','vendor_name')

In [129]:
## swapped ratecodeid with ratecode_key
fact_table = fact_table.join(ratecode_DIM, 'ratecodeid', 'inner').\
                drop('ratecodeid','description')

In [132]:
## pickup location id
fact_table = fact_table.join(taxizone_DIM, fact_table.pulocationid == taxizone_DIM.locationid, 'inner').\
                drop('pulocationid','locationid').\
                withColumnRenamed('zone_key', 'puzone_key')

In [133]:
## dropoff location id
fact_table = fact_table.join(taxizone_DIM, fact_table.dolocationid == taxizone_DIM.locationid, 'inner').\
                drop('dolocationid','locationid').\
                withColumnRenamed('zone_key', 'dozone_key')

In [136]:
## payment type
fact_table = fact_table.join(paymenttype_DIM, 'payment_type', 'inner').\
                drop('payment_type','description')

In [139]:
cols_in_order = ['vendor_key',
                 'putime_key',
                 'dotime_key',
                 'passenger_count',
                 'trip_distance',
                 'ratecode_key',
                 'store_and_fwd_flag',
                 'puzone_key',
                 'dozone_key',
                 'paymenttype_key',
                 'fare_amount',
                 'extra',
                 'mta_tax',
                 'tip_amount',
                 'tolls_amount',
                 'improvement_surcharge',
                 'total_amount',
                 'congestion_surcharge',
                 'airport_fee',
                 'errordata']

In [142]:
## ordering columns in the fact table
fact_table = fact_table[*cols_in_order]

In [145]:
fact_table.limit(100).count()

23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/12/05 21:47:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                

100

In [None]:
# !pip install psycopg2-binary

In [181]:
datetime_DIM.count()

                                                                                

10181689