In [7]:
from pathlib import Path
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("NYC Taxi Analysis") \
    .master("local[*]") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.bindAddress", "127.0.0.1") \
    .config("spark.ui.port", "4040") \
    .getOrCreate()


data_path = Path("../data/raw")
all_files = sorted(data_path.glob("yellow_tripdata_2016-*.parquet"))

train_files = [str(f) for f in all_files if f.name <= "yellow_tripdata_2016-05.parquet"]
eval_files  = [str(f) for f in all_files if f.name == "yellow_tripdata_2016-11.parquet"]
holdout_files = [str(f) for f in all_files if f.name == "yellow_tripdata_2016-12.parquet"]

train = spark.read.parquet(*train_files)
eval = spark.read.parquet(*eval_files) if eval_files else None
holdout = spark.read.parquet(*holdout_files) if holdout_files else None


print(f"train rows : {train.count()}")
print(f"eval rows : {eval.count()}")
print(f"holdout rows : {holdout.count()}")


train rows : 58244348
eval rows : 10102128
holdout rows : 10446697


In [8]:
train_pd = train.limit(10).toPandas()
train_pd.head(5)


Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee
0,1,2016-01-01 00:12:22,2016-01-01 00:29:14,1,3.2,1,N,48,262,1,14.0,0.5,0.5,3.06,0.0,0.3,18.36,,
1,1,2016-01-01 00:41:31,2016-01-01 00:55:10,2,1.0,1,N,162,48,2,9.5,0.5,0.5,0.0,0.0,0.3,10.8,,
2,1,2016-01-01 00:53:37,2016-01-01 00:59:57,1,0.9,1,N,246,90,2,6.0,0.5,0.5,0.0,0.0,0.3,7.3,,
3,1,2016-01-01 00:13:28,2016-01-01 00:18:07,1,0.8,1,N,170,162,2,5.0,0.5,0.5,0.0,0.0,0.3,6.3,,
4,1,2016-01-01 00:33:04,2016-01-01 00:47:14,1,1.8,1,N,161,140,2,11.0,0.5,0.5,0.0,0.0,0.3,12.3,,


In [9]:
eval_pandas = eval.limit(10).toPandas()
eval_pandas.head(10)

Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee
0,1,2016-11-01 00:02:29,2016-11-01 00:08:06,1,1.4,1,N,48,143,1,6.5,0.5,0.5,1.55,0.0,0.3,9.35,,
1,1,2016-11-01 00:04:14,2016-11-01 00:11:35,1,2.1,1,N,239,41,1,9.0,0.5,0.5,2.05,0.0,0.3,12.35,,
2,1,2016-11-01 00:12:38,2016-11-01 00:15:58,1,1.0,1,N,41,42,2,5.5,0.5,0.5,0.0,0.0,0.3,6.8,,
3,1,2016-11-01 00:17:48,2016-11-01 00:20:28,1,0.9,1,N,239,151,2,4.5,0.5,0.5,0.0,0.0,0.3,5.8,,
4,1,2016-11-01 00:28:55,2016-11-01 00:31:23,1,0.6,1,N,186,90,1,4.0,0.5,0.5,1.05,0.0,0.3,6.35,,
5,1,2016-11-01 00:20:25,2016-11-01 00:27:30,2,1.6,1,N,107,162,2,7.5,0.5,0.5,0.0,0.0,0.3,8.8,,
6,1,2016-11-01 00:46:55,2016-11-01 00:50:24,1,0.6,1,N,163,141,1,4.5,0.5,0.5,1.5,0.0,0.3,7.3,,
7,1,2016-11-01 00:08:04,2016-11-01 00:14:57,1,1.5,1,N,79,233,2,6.5,0.5,0.5,0.0,0.0,0.3,7.8,,
8,1,2016-11-01 00:18:25,2016-11-01 00:21:05,1,0.2,1,N,233,229,1,3.0,0.5,0.5,0.01,0.0,0.3,4.31,,
9,1,2016-11-01 00:03:33,2016-11-01 00:09:56,1,1.2,1,N,161,142,1,6.5,0.5,0.5,2.3,0.0,0.3,10.1,,


In [10]:
holdout_pandas = holdout.limit(10).toPandas()
holdout_pandas.head(10)

Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,airport_fee
0,1,2016-12-01 00:26:26,2016-12-01 00:26:54,1,0.0,1,Y,145,145,2,2.5,0.5,0.5,0.0,0.0,0.3,3.8,,
1,1,2016-12-01 00:08:13,2016-12-01 00:24:20,1,4.2,1,N,262,226,2,15.5,0.5,0.5,0.0,0.0,0.3,16.8,,
2,1,2016-12-01 00:36:29,2016-12-01 00:40:16,1,1.1,1,N,238,75,2,5.5,0.5,0.5,0.0,0.0,0.3,6.8,,
3,1,2016-12-01 00:55:28,2016-12-01 01:01:04,1,1.2,1,N,237,230,2,6.5,0.5,0.5,0.0,0.0,0.3,7.8,,
4,2,2016-12-01 00:13:08,2016-12-01 00:29:21,1,1.48,1,N,142,161,1,11.0,0.5,0.5,3.08,0.0,0.3,15.38,,
5,2,2016-12-01 00:43:16,2016-12-01 00:53:22,1,2.25,1,N,142,262,1,9.0,0.5,0.5,2.58,0.0,0.3,12.88,,
6,1,2016-12-01 00:27:18,2016-12-01 00:42:36,1,3.4,1,N,234,236,2,13.5,0.5,0.5,0.0,0.0,0.3,14.8,,
7,1,2016-12-01 00:35:50,2016-12-01 00:38:40,1,0.7,1,N,237,237,1,4.5,0.5,0.5,1.15,0.0,0.3,6.95,,
8,1,2016-12-01 00:19:45,2016-12-01 00:28:24,1,2.2,1,N,162,107,2,9.0,0.5,0.5,0.0,0.0,0.3,10.3,,
9,1,2016-12-01 00:36:40,2016-12-01 00:44:36,1,1.2,1,N,114,79,1,7.0,0.5,0.5,0.0,0.0,0.3,8.3,,


In [11]:
train = train.repartition(4)  # 4 файла вместо сотен
train.write.parquet("../data/raw/train")
eval = eval.repartition(4)  # 4 файла вместо сотен
eval.write.parquet("../data/raw/eval")
holdout = holdout.repartition(4)  # 4 файла вместо сотен
holdout.write.parquet("../data/raw/holdout")
spark.stop()



AnalysisException: [PATH_ALREADY_EXISTS] Path file:/Users/anatolijperederij/PycharmProjects/nyc-taxi-ml-pipeline/data/raw/holdout already exists. Set mode as "overwrite" to overwrite the existing path. SQLSTATE: 42K04