In [None]:
import sys

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, to_timestamp, date_trunc, row_number, split,
    regexp_extract, regexp_replace, when,
    monotonically_increasing_id, broadcast
)
from pyspark.sql import functions as F
from pyspark.sql import Window

def _load_rides(spark, citibike_paths):
    """Read the Citibike ride CSV files and keep only the columns used downstream.

    Args:
        spark: Active SparkSession.
        citibike_paths: List of CSV file or directory paths.

    Returns:
        DataFrame with the selected ride columns renamed to snake_case.
        All columns are strings because no schema is supplied or inferred.
    """
    return (
        spark.read
        .option("header", "true")
        .option("recursiveFileLookup", "true")
        .csv(citibike_paths)
        .select(
            col("tripduration").alias("trip_duration"),
            col("starttime").alias("trip_start_time"),
            col("startstationid").alias("start_station_id"),
            col("startstationlatitude").alias("start_station_latitude"),
            col("startstationlongitude").alias("start_station_longitude"),
            col("endstationid").alias("end_station_id"),
            col("usertype").alias("user_type"),
        )
    )


def _with_zone(rides):
    """Add ``borough`` and ``zone`` columns derived from the start station position.

    The bounding boxes use inclusive ``between`` and overlap on their edges;
    the first matching ``when`` branch wins, so branch order matters.

    Args:
        rides: DataFrame with ``start_station_latitude`` / ``start_station_longitude``.

    Returns:
        The input DataFrame with ``borough`` and ``zone`` appended.
    """
    lat = col("start_station_latitude")
    lon = col("start_station_longitude")
    manhattan_lon = lon.between(-74.02, -73.93)
    east_lon = lon.between(-73.93, -73.70)

    borough = (
        when(lat.between(40.70, 40.75) & manhattan_lon, "Downtown Manhattan")
        .when(lat.between(40.75, 40.82) & manhattan_lon, "Midtown Manhattan")
        .when(lat.between(40.82, 40.88) & manhattan_lon, "Uptown Manhattan")
        .when(lat.between(40.57, 40.70) & lon.between(-74.05, -73.85), "Brooklyn")
        .when(lat.between(40.70, 40.80) & east_lon, "Queens")
        .when((lat >= 40.88) & (lat < 41.00) & east_lon, "The Bronx")
        .otherwise("Other")
    )
    rides = rides.withColumn("borough", borough)

    # Standardize borough names to the "zone" labels used for the air quality data.
    zone = (
        when(col("borough") == "Downtown Manhattan", "lower-manhattan")
        .when(col("borough") == "Midtown Manhattan", "mid-manhattan")
        .when(col("borough") == "Uptown Manhattan", "upper-manhattan")
        .when(col("borough") == "Brooklyn", "Brooklyn")
        .when(col("borough") == "Queens", "Queens")
        .when(col("borough") == "The Bronx", "Bronx")
        .otherwise("Other")
    )
    return rides.withColumn("zone", zone)


def _load_air_quality(spark, air_quality_path):
    """Read the air quality CSV, keep 2017-2019 UHF 42 rows and add a ``zone`` column.

    Args:
        spark: Active SparkSession.
        air_quality_path: CSV file or directory path.

    Returns:
        DataFrame with ``start_date`` parsed to a date and ``zone`` holding the
        district label derived from ``geo_join_id``.
    """
    air_quality = (
        spark.read
        .option("header", "true")
        .option("recursiveFileLookup", "true")
        .csv(air_quality_path)
        .select(
            col("indicator_id"),
            col("name").alias("pollutants"),
            col("measure"),
            col("geo_type_name"),
            col("geo_join_id"),
            col("geo_place_name"),
            col("start_date"),
        )
        # F.to_date: the bare name `to_date` is not imported in this file.
        # NOTE(review): assumes the CSV stores start_date as yyyy/MM/dd; the
        # previous join code re-parsed it as M/d/yyyy -- confirm against the data.
        .withColumn("start_date", F.to_date(col("start_date"), "yyyy/MM/dd"))
        .filter(col("start_date").between(
            F.to_date(F.lit("2017/01/01"), "yyyy/MM/dd"),
            F.to_date(F.lit("2019/12/31"), "yyyy/MM/dd"),
        ))
        # Retain only records corresponding to UHF 42
        .filter(col("geo_type_name") == "UHF 42")
    )

    # Map UHF zone numbers to broader district names
    geo_id = col("geo_join_id")
    return air_quality.withColumn(
        "zone",
        F.when(geo_id.isin([308, 309, 310]), "lower-manhattan")
         .when(geo_id.isin([306, 307]), "mid-manhattan")
         .when(geo_id.isin([301, 302, 303, 304, 305]), "upper-manhattan")
         .when(geo_id.isin([201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211]), "Brooklyn")
         .when(geo_id.isin([101, 102, 103, 104, 105, 106, 107]), "Bronx")
         .when(geo_id.isin([401, 402, 403, 404, 405, 406, 407, 408, 409, 410]), "Queens")
         .when(geo_id.isin([501, 502, 503, 504]), "Staten Island")
         .otherwise("Unknown")
    )


def _join_rides_with_air_quality(rides, air_quality):
    """Left-join rides to air quality rows on zone and calendar date, then aggregate.

    Args:
        rides: DataFrame with ``zone``, ``trip_start_time`` and ``trip_start_timestamp``.
        air_quality: DataFrame with ``zone``, a date-typed ``start_date`` and the
            measurement columns.

    Returns:
        One row per (trip_start_time, trip_start_date, zone, pollutants) with the
        average ``measure`` and the first ``geo_join_id`` / ``geo_place_name``.
        Other ride columns are not carried through the aggregation.
    """
    # Render both sides as yyyy-MM-dd strings so they can be compared by day.
    rides = rides.withColumn(
        "trip_start_date",
        F.date_format(col("trip_start_timestamp"), "yyyy-MM-dd"),
    )
    # start_date is already a date here, so it is formatted directly rather
    # than being parsed a second time.
    air_quality = air_quality.withColumn(
        "start_date_formatted",
        F.date_format(col("start_date"), "yyyy-MM-dd"),
    )

    # Join rides data with air quality data on zone and date.
    joined = rides.join(
        air_quality,
        (rides["zone"] == air_quality["zone"])
        & (rides["trip_start_date"] == air_quality["start_date_formatted"]),
        "left",
    ).select(
        rides["*"],
        air_quality["pollutants"],
        air_quality["measure"],
        air_quality["geo_join_id"],
        air_quality["geo_place_name"],
    )

    # Several UHF areas map to one zone; average the measure to collapse them.
    return (
        joined
        .groupBy("trip_start_time", "trip_start_date", "zone", "pollutants")
        .agg(
            F.avg("measure").alias("measure"),
            F.first("geo_join_id").alias("geo_join_id"),
            F.first("geo_place_name").alias("geo_place_name"),
        )
    )


def main():
    """Join Citibike rides with air quality measurements and write the result.

    Reads three command-line arguments from ``sys.argv``:
      1. Citibike rides CSV file paths (comma-separated)
      2. Air quality CSV file path
      3. Output directory path for the final joined data

    Side effects: writes the zoned rides to ``../output/citibike_zones`` and the
    joined data to the output path, both as parquet in overwrite mode. Prints a
    usage line and exits with status 1 when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: spark-submit spark-job.py [citibike_paths (comma-separated)] [air_quality_path] [output_path]")
        sys.exit(1)

    # Parse command-line arguments
    citibike_paths = sys.argv[1].split(",")
    air_quality_path = sys.argv[2]
    output_path = sys.argv[3]

    spark = SparkSession.builder.getOrCreate()
    # This is a session-level SQL setting; passing it as a CSV reader option
    # (as before) has no effect on how files are split.
    spark.conf.set("spark.sql.files.maxPartitionBytes", "128MB")

    ######################################################
    # Load and preprocess Citibike rides data
    ######################################################
    rides = _with_zone(_load_rides(spark, citibike_paths))

    # Convert trip_start_time to timestamp for potential further processing.
    rides = rides.withColumn(
        "trip_start_timestamp",
        to_timestamp(col("trip_start_time"), "MM/dd/yyyy HH:mm:ss"),
    )

    # Write the processed rides data (with zones) for separate usage or debugging.
    rides.write \
        .option("maxRecordsPerFile", 100000) \
        .mode("overwrite") \
        .parquet("../output/citibike_zones")

    ######################################################
    # Load and preprocess Air Quality data
    ######################################################
    air_quality = _load_air_quality(spark, air_quality_path)

    ######################################################
    # Join rides data with air quality data on zone and date
    ######################################################
    final = _join_rides_with_air_quality(rides, air_quality)

    # Write final joined data to parquet, limiting each file to 100,000 records.
    final.write \
        .option("maxRecordsPerFile", 100000) \
        .mode("overwrite") \
        .parquet(output_path)

# Run the job only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

NameError: name 'rides' is not defined