In [None]:
import gc
import logging
import os
from io import StringIO

import boto3
import pandas as pd
from azure.identity import ClientSecretCredential
from azure.storage.blob import BlobServiceClient
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    DateType,
    DoubleType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)


In [6]:
# ---------- Spark session ----------

# Create (or reuse) the session. The PostgreSQL JDBC driver is requested
# through "spark.jars.packages"; the remaining options set executor cores
# and driver/executor memory.
spark = (
    SparkSession.builder
    .appName("GTFS Load")
    .config("spark.jars.packages", "org.postgresql:postgresql:42.7.7")
    .config("spark.executor.cores", "10")
    .config("spark.driver.memory", "2g")
    .config("spark.executor.memory", "6g")
    .getOrCreate()
)

In [None]:
# ---------- Read Latest CSV from Azure Data Lake To Pandas DataFrames ----------

# Azure Config
# Service-principal credentials come from the environment so that no secret
# has to be typed into (or committed with) the notebook. When a variable is
# not set the value falls back to the previous empty-string placeholder.
TENANT_ID = os.environ.get("AZURE_TENANT_ID", "")
CLIENT_ID = os.environ.get("AZURE_CLIENT_ID", "")
CLIENT_SECRET = os.environ.get("AZURE_CLIENT_SECRET", "")
ACCOUNT_NAME = 'gtfsdls'
CONTAINER_NAME = 'transitbatchlatest'


# Connect to Azure Blob
# `container_client` is the handle used by the reader functions defined below.
account_url = f"https://{ACCOUNT_NAME}.blob.core.windows.net"
credential = ClientSecretCredential(TENANT_ID, CLIENT_ID, CLIENT_SECRET)
blob_service_client = BlobServiceClient(account_url=account_url, credential=credential)
container_client = blob_service_client.get_container_client(CONTAINER_NAME)

def read_csv_from_adls(blob_name):
    """Download one blob from `container_client` and parse it as CSV.

    Args:
        blob_name: Full name (path) of the blob inside the container.

    Returns:
        A pandas DataFrame with the parsed CSV content. Loading is
        best-effort: if the download, decoding or parsing fails for any
        reason, the failure is logged as a warning and an empty DataFrame
        is returned instead of raising.
    """
    try:
        blob_client = container_client.get_blob_client(blob_name)
        payload = blob_client.download_blob().readall()
        # "utf-8-sig" decodes plain UTF-8 exactly like "utf-8" but also drops
        # a leading byte-order mark, which would otherwise end up glued to
        # the first column name.
        return pd.read_csv(StringIO(payload.decode("utf-8-sig")))
    except Exception:
        # Keep the original fallback, but leave a trace so a missing or
        # malformed file is not mistaken for a genuinely empty table.
        logging.getLogger(__name__).warning(
            "Could not read CSV blob %r; using an empty DataFrame instead.",
            blob_name,
            exc_info=True,
        )
        return pd.DataFrame()

def find_and_read_gtfs_files(base_path="rowdata/latest"):
    """Load every GTFS table found under `base_path` into pandas.

    Args:
        base_path: Blob-name prefix passed to `container_client.list_blobs`.

    Returns:
        A dict keyed "<table>_df" (agency, calendar, calendar_dates, routes,
        shapes, stops, stop_times, trips). Each value is the row-wise
        concatenation of every blob whose name ends with "<table>.csv", or
        an empty DataFrame when no blob matches.
    """
    table_names = (
        "agency",
        "calendar",
        "calendar_dates",
        "routes",
        "shapes",
        "stops",
        "stop_times",
        "trips",
    )

    # One listing call; every table is matched against these names.
    blob_names = [
        blob.name
        for blob in container_client.list_blobs(name_starts_with=base_path)
    ]

    frames = {}
    for table in table_names:
        suffix = f"{table}.csv"
        hits = [name for name in blob_names if name.endswith(suffix)]
        if not hits:
            frames[f"{table}_df"] = pd.DataFrame()
            continue
        parts = [read_csv_from_adls(name) for name in hits]
        frames[f"{table}_df"] = pd.concat(parts, ignore_index=True)

    return frames

# Usage
data = find_and_read_gtfs_files()

# Pull each table out of the result dict into its own notebook variable.
(
    agency_df,
    calendar_df,
    calendar_dates_df,
    routes_df,
    shapes_df,
    stops_df,
    stop_times_df,
    trips_df,
) = (
    data[key]
    for key in (
        "agency_df",
        "calendar_df",
        "calendar_dates_df",
        "routes_df",
        "shapes_df",
        "stops_df",
        "stop_times_df",
        "trips_df",
    )
)


# # ---------- Read Latest CSV from MinIO To Pandas DataFrames ----------

# # MinIO Config Connection
# s3_client = boto3.client(
#     's3',
#     endpoint_url='http://minio:9000',
#     aws_access_key_id='minio',
#     aws_secret_access_key='minio123'
# )

# def read_csv_from_minio(bucket, key):
#     try:
#         response = s3_client.get_object(Bucket=bucket, Key=key)
#         return pd.read_csv(StringIO(response['Body'].read().decode('utf-8')))
#     except Exception:
#         return pd.DataFrame()

# def find_and_read_gtfs_files(bucket, base_path="rowdata/latest"):
#     gtfs_files = {
#         "agency_df": "agency.csv",
#         "calendar_df": "calendar.csv", 
#         "calendar_dates_df": "calendar_dates.csv",
#         "routes_df": "routes.csv",
#         "shapes_df": "shapes.csv",
#         "stops_df": "stops.csv",
#         "stop_times_df": "stop_times.csv",
#         "trips_df": "trips.csv"
#     }
    
#     response = s3_client.list_objects_v2(Bucket=bucket, Prefix=base_path)
#     all_files = [obj['Key'] for obj in response.get('Contents', [])]
    
#     result = {}
#     for var_name, filename in gtfs_files.items():
#         matching = [f for f in all_files if f.endswith(filename)]
#         if matching:
#             df = pd.concat([read_csv_from_minio(bucket, f) for f in matching], ignore_index=True)
#         else:
#             df = pd.DataFrame()
#         result[var_name] = df
    
#     return result

# bucket = "transitbatchlatest"
# data = find_and_read_gtfs_files(bucket)

# agency_df         = data["agency_df"]
# calendar_df       = data["calendar_df"]
# calendar_dates_df = data["calendar_dates_df"]
# routes_df         = data["routes_df"]
# shapes_df         = data["shapes_df"]
# stops_df          = data["stops_df"]
# stop_times_df     = data["stop_times_df"]
# trips_df          = data["trips_df"]


In [None]:
# ---------- Convert Pandas DataFrames to Spark DataFrames with Schema ----------

# Column name -> Spark type for each GTFS table. These maps drive the casts
# performed by the pandas-to-Spark conversion; columns not listed in a map
# are not cast.

# Every agency column is plain text.
agency_schema_map = {
    column: StringType()
    for column in (
        "agency_id",
        "agency_name",
        "agency_url",
        "agency_timezone",
        "agency_lang",
        "agency_phone",
    )
}

# Weekday flags and the two dates are all read as integers.
calendar_schema_map = {
    "service_id": StringType(),
    **{
        day: IntegerType()
        for day in (
            "monday",
            "tuesday",
            "wednesday",
            "thursday",
            "friday",
            "saturday",
            "sunday",
        )
    },
    "start_date": IntegerType(),
    "end_date": IntegerType(),
}

calendar_dates_schema_map = dict(
    service_id=StringType(),
    date=IntegerType(),
    exception_type=IntegerType(),
)

routes_schema_map = dict(
    route_id=StringType(),
    agency_id=StringType(),
    route_short_name=StringType(),
    route_long_name=StringType(),
    route_desc=StringType(),
    route_type=IntegerType(),
    route_color=StringType(),
    route_text_color=StringType(),
)

shapes_schema_map = dict(
    shape_id=StringType(),
    shape_pt_lat=DoubleType(),
    shape_pt_lon=DoubleType(),
    shape_pt_sequence=IntegerType(),
)

# NOTE(review): stop_id is typed as an integer here and in stop_times;
# confirm the feed never uses non-numeric stop ids.
stops_schema_map = dict(
    stop_id=IntegerType(),
    stop_name=StringType(),
    stop_desc=StringType(),
    stop_lat=DoubleType(),
    stop_lon=DoubleType(),
    zone_id=StringType(),
    stop_url=StringType(),
    location_type=IntegerType(),
    parent_station=StringType(),
)

stop_times_schema_map = dict(
    trip_id=StringType(),
    arrival_time=StringType(),
    departure_time=StringType(),
    stop_id=IntegerType(),
    stop_sequence=IntegerType(),
    pickup_type=IntegerType(),
    drop_off_type=IntegerType(),
    timepoint=IntegerType(),
)

# NOTE(review): block_id is typed as an integer; confirm the feed never
# uses non-numeric block ids.
trips_schema_map = dict(
    route_id=StringType(),
    service_id=StringType(),
    trip_id=StringType(),
    trip_headsign=StringType(),
    direction_id=IntegerType(),
    block_id=IntegerType(),
    shape_id=StringType(),
)

# Conversion Function with Drop Duplication
def pandas_to_spark_with_schema(pdf, schema_map):
    """Convert a pandas DataFrame to a trimmed, typed, de-duplicated Spark DataFrame.

    Args:
        pdf: Source pandas DataFrame (may be empty).
        schema_map: Dict mapping column name -> Spark data type.

    Returns:
        A Spark DataFrame. For an empty `pdf` it has zero rows and exactly
        the columns of `schema_map`. Otherwise every column of `pdf` is
        passed through `trim`, columns named in `schema_map` are cast to the
        mapped type, other columns are left as trimmed values, and duplicate
        rows are dropped.
    """
    if pdf.empty:
        # createDataFrame needs a real schema object; a plain
        # {name: type} dict is not accepted as `schema`, so build a
        # StructType with nullable fields from the map.
        empty_schema = StructType(
            [StructField(name, dtype, True) for name, dtype in schema_map.items()]
        )
        return spark.createDataFrame([], schema=empty_schema)

    sdf = spark.createDataFrame(pdf)

    # Trim and cast in one projection instead of one withColumn call per
    # column, which keeps the query plan small for wide tables.
    projected = []
    for col_name in sdf.columns:
        expr = F.trim(F.col(col_name))
        if col_name in schema_map:
            expr = expr.cast(schema_map[col_name])
        projected.append(expr.alias(col_name))

    return sdf.select(*projected).dropDuplicates()

# Build one Spark DataFrame per GTFS table by pairing each pandas frame
# with its schema map.
(
    agency_sdf,
    calendar_sdf,
    calendar_dates_sdf,
    routes_sdf,
    shapes_sdf,
    stops_sdf,
    stop_times_sdf,
    trips_sdf,
) = (
    pandas_to_spark_with_schema(source_pdf, source_schema)
    for source_pdf, source_schema in (
        (agency_df, agency_schema_map),
        (calendar_df, calendar_schema_map),
        (calendar_dates_df, calendar_dates_schema_map),
        (routes_df, routes_schema_map),
        (shapes_df, shapes_schema_map),
        (stops_df, stops_schema_map),
        (stop_times_df, stop_times_schema_map),
        (trips_df, trips_schema_map),
    )
)



In [None]:
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================

In [None]:
# ---------- Convert columns from IntegerType to DateType ----------

# Each date column is rendered as a string and then parsed with the
# "yyyyMMdd" pattern.
calendar_sdf = (
    calendar_sdf
    .withColumn(
        "start_date",
        F.to_date(F.col("start_date").cast("string"), "yyyyMMdd"),
    )
    .withColumn(
        "end_date",
        F.to_date(F.col("end_date").cast("string"), "yyyyMMdd"),
    )
)

calendar_dates_sdf = calendar_dates_sdf.withColumn(
    "date",
    F.to_date(F.col("date").cast("string"), "yyyyMMdd"),
)


In [None]:
# ---------- Replace "NaN" strings with NULL ----------

# NOTE(review): a string-valued replace is documented to skip non-string
# columns, so float NaN values in numeric columns are presumably left
# untouched by this step - confirm that is intended.
(
    agency_sdf,
    calendar_sdf,
    calendar_dates_sdf,
    routes_sdf,
    shapes_sdf,
    stops_sdf,
    stop_times_sdf,
    trips_sdf,
) = (
    table_sdf.replace("NaN", None)
    for table_sdf in (
        agency_sdf,
        calendar_sdf,
        calendar_dates_sdf,
        routes_sdf,
        shapes_sdf,
        stops_sdf,
        stop_times_sdf,
        trips_sdf,
    )
)


In [None]:
# ---------- Add control columns ----------
start_dt_scd = F.current_timestamp()


def _with_scd_columns(sdf):
    """Return `sdf` with the three control columns appended.

    Adds "start_dt_scd" (the shared current-timestamp expression),
    "end_dt_scd" (a NULL timestamp) and "is_current" (literal True).
    """
    return (
        sdf
        .withColumn("start_dt_scd", start_dt_scd)
        .withColumn("end_dt_scd", F.lit(None).cast("timestamp"))
        .withColumn("is_current", F.lit(True))
    )


stg_agency_df = _with_scd_columns(agency_sdf)
stg_calendar_df = _with_scd_columns(calendar_sdf)
stg_calendar_dates_df = _with_scd_columns(calendar_dates_sdf)
stg_routes_df = _with_scd_columns(routes_sdf)
stg_shapes_df = _with_scd_columns(shapes_sdf)
stg_stops_df = _with_scd_columns(stops_sdf)
stg_stop_times_df = _with_scd_columns(stop_times_sdf)
stg_trips_df = _with_scd_columns(trips_sdf)


In [None]:
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================
#=================================================================================================================

In [None]:
# Postgres Config Connection

# The JDBC URL and login can be overridden through environment variables so
# real credentials never have to be written into the notebook. When a
# variable is not set, the previous hardcoded value is used.
jdbc_url = os.environ.get(
    "GTFS_STAGING_JDBC_URL",
    "jdbc:postgresql://postgres:5432/gtfs_batch_staging",
)
connection_properties = {
    "user": os.environ.get("GTFS_STAGING_DB_USER", "admin"),
    "password": os.environ.get("GTFS_STAGING_DB_PASSWORD", "password"),
    "driver": "org.postgresql.Driver",
}

In [None]:
def _overwrite_staging_table(sdf, table_name):
    """Write `sdf` to `table_name` over JDBC in "overwrite" mode."""
    sdf.write.jdbc(
        url=jdbc_url,
        table=table_name,
        mode="overwrite",
        properties=connection_properties,
    )


# All staging tables except stop_times, which is written in its own cell.
_overwrite_staging_table(stg_agency_df, "stg_agency")
_overwrite_staging_table(stg_calendar_df, "stg_calendar")
_overwrite_staging_table(stg_calendar_dates_df, "stg_calendar_dates")
_overwrite_staging_table(stg_routes_df, "stg_routes")
_overwrite_staging_table(stg_shapes_df, "stg_shapes")
_overwrite_staging_table(stg_stops_df, "stg_stops")
_overwrite_staging_table(stg_trips_df, "stg_trips")

                                                                                

In [None]:
# stop_times is redistributed into 50 partitions before the JDBC write
# (presumably to spread the insert work for the largest table - confirm).
(
    stg_stop_times_df
    .repartition(50)
    .write
    .mode("overwrite")
    .jdbc(url=jdbc_url, table="stg_stop_times", properties=connection_properties)
)

                                                                                

In [None]:
# Release anything Spark may have cached for the notebook's DataFrames.
for df in [
    agency_sdf, calendar_sdf, calendar_dates_sdf, routes_sdf,
    shapes_sdf, stops_sdf, stop_times_sdf, trips_sdf,
    stg_agency_df, stg_calendar_df, stg_calendar_dates_df, stg_routes_df,
    stg_shapes_df, stg_stops_df, stg_stop_times_df, stg_trips_df
]:
    df.unpersist(blocking=True)

# The loop variable still points at the last DataFrame; drop it so the
# `del` statements below really remove the final reference.
del df

# Clear cache from Spark
spark.catalog.clearCache()

# Stop Spark
spark.stop()

# Deleting variables from memory
del agency_sdf, calendar_sdf, calendar_dates_sdf, routes_sdf
del shapes_sdf, stops_sdf, stop_times_sdf, trips_sdf
del stg_agency_df, stg_calendar_df, stg_calendar_dates_df, stg_routes_df
del stg_shapes_df, stg_stops_df, stg_stop_times_df, stg_trips_df, spark

# Garbage collection imposed
gc.collect()