In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# ————— TaxiDataProcessorDuckDB.py —————
import os
import duckdb
import pandas as pd
import glob
from matplotlib import pyplot as plt
import folium
from datetime import datetime, date, timedelta

class TaxiDataProcessorDuckDB:
    def __init__(self,
                 parquet_pattern: str,
                 zone_csv: str,
                 output_path: str = None,
                 output_format: str = 'parquet'):
        """
        :param parquet_pattern: glob path to input parquet files (e.g., '/path/to/yellow_tripdata_*.parquet')
        :param zone_csv: path to taxi_zone_with_coords.csv
        :param output_path: path or directory for writing processed data
        :param output_format: 'parquet' or 'csv'
        """
        self.parquet_pattern = parquet_pattern
        self.zone_csv = zone_csv
        self.output_path = output_path
        self.output_format = output_format
        # Initialize an in-memory DuckDB connection
        self.con = duckdb.connect(database=':memory:')
        # Register zone CSV as a DuckDB table
        self.con.execute(f"""
            CREATE TABLE zones AS
            SELECT CAST(LocationID AS INTEGER) AS zone_id,
                   CAST(lat AS DOUBLE) AS lat,
                   CAST(lon AS DOUBLE) AS lon
            FROM read_csv_auto('{self.zone_csv}')
        """)

    def load_and_clean(self):
        """
        Step 1–3: Read parquet files into DuckDB, filter invalid fares and missing zones.
        """
        # Read all parquet files matching pattern into a DuckDB table
        self.con.execute(f"""
            CREATE TABLE trips AS
            SELECT * FROM read_parquet('{self.parquet_pattern}')
        """)

        # Filter out invalid fares/amounts and null PULocationID/DOLocationID
        self.con.execute("""
            CREATE TABLE trips_clean AS
            SELECT *
            FROM trips
            WHERE fare_amount > 0
              AND total_amount >= 0
              AND extra >= 0
              AND PULocationID IS NOT NULL
              AND DOLocationID IS NOT NULL
        """)
        # Drop original trips to free memory
        self.con.execute("DROP TABLE trips;")

    def map_coordinates(self):
        """
        Step 4–5: Join zone coordinates to trips_clean and drop rows with missing coordinates.
        """
        self.con.execute("""
            CREATE TABLE trips_coord AS
            SELECT t.*,
                   p.lat AS pickup_lat, p.lon AS pickup_lon,
                   d.lat AS dropoff_lat, d.lon AS dropoff_lon
            FROM trips_clean t
            LEFT JOIN zones p ON t.PULocationID = p.zone_id
            LEFT JOIN zones d ON t.DOLocationID = d.zone_id
            WHERE p.lat IS NOT NULL
              AND p.lon IS NOT NULL
              AND d.lat IS NOT NULL
              AND d.lon IS NOT NULL
        """)
        self.con.execute("DROP TABLE trips_clean;")

    def process_datetime(self):
        """
        Step 6: Convert pickup datetime to DATE/TIMESTAMP, extract year/month/day.
        """
        # Create a new table with parsed datetime and extracted components
        self.con.execute("""
            CREATE TABLE trips_dt AS
            SELECT *,
                   CAST(strftime('%Y', tpep_pickup_datetime) AS INTEGER) AS year,
                   CAST(strftime('%m', tpep_pickup_datetime) AS INTEGER) AS month,
                   CAST(strftime('%d', tpep_pickup_datetime) AS INTEGER) AS day,
                   CAST(tpep_pickup_datetime AS TIMESTAMP) AS pickup_dt
            FROM trips_coord
            WHERE tpep_pickup_datetime >= '2023-01-01'
              AND tpep_pickup_datetime <  '2024-01-01'
        """)
        self.con.execute("DROP TABLE trips_coord;")

        # Print counts by year/month/day
        for comp in ['year', 'month', 'day']:
            print(f"Counts by {comp}:")
            df_counts = self.con.execute(f"SELECT {comp}, COUNT(*) AS cnt FROM trips_dt GROUP BY {comp} ORDER BY {comp}").fetchdf()
            print(df_counts)

    def filter_map_boundary(self, min_lat: float, max_lat: float, min_lon: float, max_lon: float):
        """
        Step 7.2: Filter records outside given latitude/longitude bounds.
        """
        self.con.execute(f"""
            CREATE TABLE trips_bound AS
            SELECT *
            FROM trips_dt
            WHERE pickup_lat BETWEEN {min_lat} AND {max_lat}
              AND dropoff_lat BETWEEN {min_lat} AND {max_lat}
              AND pickup_lon BETWEEN {min_lon} AND {max_lon}
              AND dropoff_lon BETWEEN {min_lon} AND {max_lon}
        """)
        self.con.execute("DROP TABLE trips_dt;")

    def explore_distribution(self):
        """
        Step 1 (optional): Compute descriptive stats on numeric columns.
        """
        available_cols = [col for col in ['fare_amount','total_amount','extra','airport_fee']
                          if self.con.execute("PRAGMA table_info(trips_bound)").fetchdf()['name'].tolist().count(col)]
        if not available_cols:
            print("No relevant numeric columns found for descriptive statistics.")
            return
        cols_str = ", ".join(available_cols)
        desc_query = f"SELECT \n"
        # Compute min, max, avg for each column
        for col in available_cols:
            desc_query += f"MIN({col}) AS {col}_min, MAX({col}) AS {col}_max, AVG({col}) AS {col}_avg, "
        desc_query = desc_query.rstrip(', ')
        desc_query += " FROM trips_bound;"
        df_desc = self.con.execute(desc_query).fetchdf()
        print("Descriptive statistics:")
        print(df_desc)

    # def map_plot(self, map_file: str = "taxi_map.html", sample_frac: float = 0.01):
    #     """
    #     Step 7.3: Sample points and plot on Folium map.
    #     """
    #     # Count total rows to compute sample size
    #     total_rows = self.con.execute("SELECT COUNT(*) AS cnt FROM trips_bound").fetchdf()['cnt'][0]
    #     sample_size = max(int(total_rows * sample_frac), 1)
    #     # df_sample = self.con.execute(f"SELECT pickup_lat, pickup_lon, dropoff_lat, dropoff_lon FROM trips_bound SAMPLE {sample_size}").fetchdf()
    #     df_sample = self.con.execute(
    #         f"SELECT pickup_lat, pickup_lon, dropoff_lat, dropoff_lon FROM trips_bound ORDER BY RANDOM() LIMIT {sample_size};"
    #     ).fetchdf()
    #     # Determine map center
    #     min_lat, max_lat, min_lon, max_lon = self.con.execute(
    #         "SELECT MIN(pickup_lat), MAX(pickup_lat), MIN(pickup_lon), MAX(pickup_lon) FROM trips_bound"
    #     ).fetchone()
    #     center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]

    #     m = folium.Map(location=center, zoom_start=12)
    #     for _, row in df_sample.iterrows():
    #         folium.CircleMarker(location=[row['pickup_lat'], row['pickup_lon']], radius=2, fill=True, fill_opacity=0.6).add_to(m)
    #         folium.CircleMarker(location=[row['dropoff_lat'], row['dropoff_lon']], radius=2, color='red', fill=True, fill_opacity=0.6).add_to(m)
    #     m.save(map_file)
    #     print(f"Map saved to {map_file}")

    def coordinate_frequency(self, top_n: int = 20):
        """
        Step 7.4: Compute frequency of each coordinate pair and print top N.
        """
        df_freq = self.con.execute(f"""
            SELECT lat, lon, cnt FROM (
                SELECT pickup_lat AS lat, pickup_lon AS lon, COUNT(*) AS cnt
                FROM trips_bound GROUP BY pickup_lat, pickup_lon
                UNION ALL
                SELECT dropoff_lat AS lat, dropoff_lon AS lon, COUNT(*) AS cnt
                FROM trips_bound GROUP BY dropoff_lat, dropoff_lon
            ) ORDER BY cnt DESC LIMIT {top_n};
        """).fetchdf()
        print(df_freq)

    # def save_data(self):
    #     """
    #     Step 7.5: Persist the final filtered table to disk in desired format.
    #     """
    #     if not self.output_path:
    #         raise ValueError("No output_path specified for saving data.")
    #     if self.output_format == 'parquet':
    #         self.con.execute(f"COPY trips_bound TO '{self.output_path}' (FORMAT PARQUET);")
    #     elif self.output_format == 'csv':
    #         self.con.execute(f"COPY trips_bound TO '{self.output_path}' (HEADER TRUE);")
    #     else:
    #         raise ValueError(f"Unsupported output format: {self.output_format}")
    #     print(f"Data saved to {self.output_path} as {self.output_format}.")
    # def save_data(self):
    #     """
    #     Save final filtered table to the processed_output directory as <original_name>_new.parquet
    #     """
    #     import os, glob
    #     # Use first matching parquet file to get original base name
    #     matched = glob.glob(self.parquet_pattern)
    #     if matched:
    #         orig = os.path.splitext(os.path.basename(matched[0]))[0]
    #     else:
    #         orig = 'data'
    #     new_name = orig + '_new.parquet'
    #     # Ensure output_path is directory
    #     out_dir = self.output_path.rstrip(os.sep)
    #     os.makedirs(out_dir, exist_ok=True)
    #     file_path = os.path.join(out_dir, new_name)
    #     print(f"Writing Parquet to: {file_path}")
    #     self.con.execute(f"COPY trips_bound TO '{file_path}' (FORMAT PARQUET);")
    #     print(f"Data saved to {file_path}")
    def save_data_by_month(self):
      """
      按照 tpep_pickup_datetime 的月份，将 trips_bound 表拆分成多个 Parquet 文件并依次写入磁盘，
      每个文件只包含该月数据，且在 COPY 时使用较小的 ROW_GROUP_SIZE 以降低单次写入的内存峰值。
      最终输出文件形如：<orig>_2023-01_new.parquet, <orig>_2023-02_new.parquet, …。
      """

      # 1. 从 parquet_pattern（原始文件模式）推断一个“原始 base 名称”：
      matched = glob.glob(self.parquet_pattern)
      if matched:
          orig = os.path.splitext(os.path.basename(matched[0]))[0]
      else:
          orig = 'data'

      # 2. 确保 output_path 存在且是目录：
      out_dir = self.output_path.rstrip(os.sep)
      os.makedirs(out_dir, exist_ok=True)

      # 3. 在 DuckDB 中查询 trips_bound 里最小和最大 pickup 时间：
      result = self.con.execute("""
          SELECT
            MIN(tpep_pickup_datetime) AS min_dt,
            MAX(tpep_pickup_datetime) AS max_dt
          FROM trips_bound;
      """).fetchall()
      if not result:
          print("❌ trips_bound 表为空，或查询失败")
          return

      min_dt_val, max_dt_val = result[0]  # 可能是 datetime，也可能是字符串

      # 4. 如果是字符串就用 fromisoformat，否则认为它已经是 datetime
      if isinstance(min_dt_val, str):
          min_dt = datetime.fromisoformat(min_dt_val).date()
      else:
          min_dt = min_dt_val.date()

      if isinstance(max_dt_val, str):
          max_dt = datetime.fromisoformat(max_dt_val).date()
      else:
          max_dt = max_dt_val.date()

      # 5. 根据 min_dt 和 max_dt，生成每个月的区间 [当月第一天, 下月第一天)：
      #    先把 min_dt 向下取到当月第一天：
      start_month = date(min_dt.year, min_dt.month, 1)
      #    把 max_dt 向上取到下月第一天（作为循环终止条件，不包含该日）：
      if max_dt.month == 12:
          end_month = date(max_dt.year + 1, 1, 1)
      else:
          end_month = date(max_dt.year, max_dt.month + 1, 1)

      curr = start_month
      while curr < end_month:
          # 计算下一个月的“第一天”
          if curr.month == 12:
              next_month = date(curr.year + 1, 1, 1)
          else:
              next_month = date(curr.year, curr.month + 1, 1)

          # 格式化成 "YYYY-MM"
          month_str = curr.strftime("%Y-%m")  # 例如 "2023-01"
          # 输出文件名示例： "tripdata_2023-01_new.parquet"
          file_name = f"{orig}_{month_str}_new.parquet"
          file_path = os.path.join(out_dir, file_name)

          print(f"▶ 正在导出 {month_str} … 输出到：{file_path}")

          # 6. 仅导出当前月的数据到 Parquet，并且指定 ROW_GROUP_SIZE
          #    这里示例取 ROW_GROUP_SIZE=500000，根据你的内存情况可以适当调整
          sql = f"""
          COPY (
              SELECT *
              FROM trips_bound
              WHERE tpep_pickup_datetime >= '{curr.isoformat()}'
                AND tpep_pickup_datetime <  '{next_month.isoformat()}'
          )
          TO '{file_path}'
          (FORMAT PARQUET, ROW_GROUP_SIZE 500000);
          """
          self.con.execute(sql)

          print(f"  ✔  {month_str} 导出完成")
          curr = next_month  # 进入下一个月

      print("✅ 按月导出完成，文件保存在：", out_dir)



    def run_all(self,
                bounds: dict = None,
                map_file: str = 'taxi_map.html',
                sample_frac: float = 0.01,
                top_n: int = 20):
        """
        Execute all steps in order:
        1. load_and_clean()
        2. map_coordinates()
        3. process_datetime()
        4. filter_map_boundary()
        5. explore_distribution()  (optional)
        6. map_plot()
        7. coordinate_frequency()
        # 8. save_data()
        # 8. save_data_by_month()
        """
        self.load_and_clean()
        self.map_coordinates()
        self.process_datetime()
        if bounds:
            self.filter_map_boundary(**bounds)
        self.explore_distribution()
        # self.map_plot(map_file=map_file, sample_frac=sample_frac)
        self.coordinate_frequency(top_n=top_n)
        # if self.output_path:
        #     self.save_data()

# === Usage Example ===
if __name__ == '__main__':
    processor = TaxiDataProcessorDuckDB(
        parquet_pattern='/content/drive/MyDrive/NYC_yellow_taxi/process_data/yellow_tripdata_*.parquet',
        zone_csv='/content/drive/MyDrive/NYC_yellow_taxi/taxi_zone_with_coords.csv',
        output_path='/content/drive/MyDrive/NYC_yellow_taxi/processed_output/',
        output_format='parquet'
    )
    NYC_BOUNDS = {
        'min_lat': 40.49612,
        'max_lat': 40.91553,
        'min_lon': -74.25559,
        'max_lon': -73.70001
    }
    processor.run_all(bounds=NYC_BOUNDS,
                     map_file='/content/drive/MyDrive/NYC_yellow_taxi/processed_output/nyc_taxi_map.html',
                     sample_frac=0.01,
                     top_n=20)
    # 然后单独保存（此时内存压力更可控）
    processor.save_data_by_month()


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Counts by year:
   year       cnt
0  2023  37293341
Counts by month:
    month      cnt
0       1  2983437
1       2  2835741
2       3  3320229
3       4  3203374
4       5  3422757
5       6  3218603
6       7  2823251
7       8  2742964
8       9  2767471
9      10  3428568
10     11  3252376
11     12  3294570
Counts by day:
    day      cnt
0     1  1230751
1     2  1202267
2     3  1206924
3     4  1177068
4     5  1222834
5     6  1244700
6     7  1287081
7     8  1263686
8     9  1254602
9    10  1272168
10   11  1286273
11   12  1295478
12   13  1300791
13   14  1342516
14   15  1328103
15   16  1279424
16   17  1262063
17   18  1282876
18   19  1258602
19   20  1241801
20   21  1191162
21   22  1122093
22   23  1051892
23   24  1082818
24   25  1168500
25   26  1185466
26   27  1213972
27   28  1241312
28   29  1066896
29   30  1061564
30   31   667658


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Descriptive statistics:
   fare_amount_min  fare_amount_max  fare_amount_avg  total_amount_min  \
0             0.01        386983.63        19.533706              0.01   

   total_amount_max  total_amount_avg  extra_min  extra_max  extra_avg  \
0         386987.63         28.565506        0.0      96.38   1.582454   

   airport_fee_min  airport_fee_max  airport_fee_avg  
0              0.0             1.75         0.140856  


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

          lat        lon      cnt
0   40.773600 -73.956600  3352795
1   40.773600 -73.956600  3252646
2   40.642948 -73.779373  1873314
3   40.765064 -73.985319  1742458
4   40.765064 -73.985319  1467630
5   40.755361 -73.967411  1380449
6   40.759822 -73.972471  1337012
7   40.737463 -74.001456  1324473
8   40.737463 -74.001456  1319178
9   40.742165 -73.988121  1286551
10  40.775714 -73.873364  1267929
11  40.759508 -73.984159  1247561
12  40.774000 -73.982000  1242133
13  40.755361 -73.967411  1187346
14  40.759508 -73.984159  1160788
15  40.748157 -73.978750  1117236
16  40.748157 -73.978750  1116325
17  40.764840 -73.985172  1094601
18  40.759822 -73.972471  1071186
19  40.774000 -73.982000  1057725


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

▶ 正在导出 2023-01 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-01_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-01 导出完成
▶ 正在导出 2023-02 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-02_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-02 导出完成
▶ 正在导出 2023-03 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-03_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-03 导出完成
▶ 正在导出 2023-04 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-04_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-04 导出完成
▶ 正在导出 2023-05 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-05_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-05 导出完成
▶ 正在导出 2023-06 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-06_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-06 导出完成
▶ 正在导出 2023-07 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-07_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-07 导出完成
▶ 正在导出 2023-08 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-08_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-08 导出完成
▶ 正在导出 2023-09 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-09_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-09 导出完成
▶ 正在导出 2023-10 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-10_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-10 导出完成
▶ 正在导出 2023-11 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-11_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-11 导出完成
▶ 正在导出 2023-12 … 输出到：/content/drive/MyDrive/NYC_yellow_taxi/processed_output/yellow_tripdata_2023-01_2023-12_new.parquet


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

  ✔  2023-12 导出完成
✅ 按月导出完成，文件保存在： /content/drive/MyDrive/NYC_yellow_taxi/processed_output


In [5]:
import gc
gc.collect()

69

In [None]:
self.con.close()