In [None]:
# Imports and constants
import pandas as pd
from sqlalchemy.sql import text
from lib import Output, TaxiDBReader, GREEN, YELLOW, TABLES, tabulate, CHUNK_SIZE, TABLE_FORMAT, SEPARATOR

O = Output('output/eda_output.txt') 
    
O.clear()
O.out(SEPARATOR)
O.out('EDA (text)')
O.out(SEPARATOR)
dr = TaxiDBReader()

# dr.setTable(year=2023, taxi_type=GREEN)

In [None]:
def runBasicStatistics(year, taxi_type):
    # basic statistics
    # note this will take long because we're trying to read the whole table
    dr.setTable(year, taxi_type)
    table_name = dr.getTableName()
    O.out(SEPARATOR)
    O.out(f'Basic Statistics - {year} {taxi_type} Taxi Trips')
    with dr.engn.connect() as conn:
        sql = text(f'SELECT * FROM {table_name}')
        df = pd.read_sql(sql, conn)
        
        O.out(tabulate(df.describe(), headers='keys', tablefmt=TABLE_FORMAT))
        with open(O.path, 'a') as f:
            df.info(buf=f)

# runBasicStatistics(2023, GREEN)

In [None]:
def runOtherStatistics(year, taxi_type): 
    dr.setTable(year, taxi_type)
    table_name = dr.getTableName()
    O.out(SEPARATOR)
    O.out(f'Other Statistics - {year} {taxi_type} Taxi Trips')
    with dr.engn.connect() as conn:
        # start = 0
        # count = 10
        # sql = text(f'SELECT * FROM {table_name} LIMIT {count} OFFSET {start} ROWS FETCH NEXT {count} ROWS ONLY')
        # month = 1 # january
        # day = 1
        # hour = 1
        # sql = text(f'select * from {table_name} where strftime(\'%m\', pickup_datetime) = \'0{month}\'')
        
        # average passenger count by pickup location id
        sql = text(f'''
        SELECT pu_location_id, avg(passenger_count) AS avg_passenger_count FROM {table_name}
        WHERE passenger_count > 0 GROUP BY pu_location_id
        ''')
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)

        # number of trips by pickup location where no passengers where recorded
        sql = text(f'''
        SELECT pu_location_id, COUNT(passenger_count) AS count_passenger_count FROM {table_name} 
        WHERE passenger_count = 0 GROUP BY pu_location_id
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)

        # number of trips where passenger_count=0
        sql = text(f'''
        SELECT COUNT(*) AS trips_no_passenger FROM {table_name} WHERE passenger_count=0
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)

        # number of trips where trip_distance=0
        sql = text(f'''
        SELECT COUNT(*) AS trips_zero_trip FROM {table_name} WHERE trip_distance=0
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)
        
        # number of trips where total_amount=0
        sql = text(f'''
        SELECT COUNT(*) AS trips_zero_amount FROM {table_name} WHERE total_amount=0
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)
        
        # number of trips where either passenger_count=0 or trip_distance=0 or total_amount=0
        sql = text(f'''
        SELECT COUNT(*) AS trips_union FROM {table_name} 
        WHERE passenger_count=0 OR trip_distance=0 OR total_amount=0
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)
        
        # number of trips where passenger_count=0 and trip_distance=0 and total_amount=0
        sql = text(f'''
        SELECT COUNT(*) AS trips_intersection FROM {table_name} 
        WHERE passenger_count=0 AND trip_distance=0 AND total_amount=0
        ''')
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)

# runOtherStatistics(2023, GREEN)

In [None]:
def fixLocationBased(year, taxi_type): 
    dr.setTable(year, taxi_type)
    table_name = dr.getTableName()
    with dr.engn.connect() as conn:    
        # check if mta_tax and tolls_amount should be fixed. If too many rows have 0 value where 
        # passenger amount and distance are valid, it means that only some rows are filled, 
        # and it would not make sense to fix.
        # total trips
        sql = text(f"SELECT count(*) FROM {table_name}")
        O.out(SEPARATOR)
        O.out(sql)
        O.out(pd.read_sql(sql, conn), is_df=True)
        
        for c in ['mta_tax', 'tolls_amount']:
            # trips where mta_tax | tolls_amount is 0 but passenger amount and distance are valid
            sql = text(f"""
    SELECT pu_location_id, do_location_id, {c} FROM {table_name} 
    WHERE passenger_count>0 AND 
        total_amount > 0 AND 
        trip_distance > 0 AND 
        {c}=0
    """)
            O.out(sql)
            O.out(pd.read_sql(sql, conn), is_df=True)

        # fix zero trip and other similar columns
        target_columns = ['trip_distance', 'fare_amount', 'mta_tax']
        prefix = 'f_'

        O.out(SEPARATOR)
        O.out(f'fix data for {target_columns}')

        for c in target_columns:
            nc = prefix + c

            sql = text(f"ALTER TABLE {table_name} ADD COLUMN {nc} FLOAT DEFAULT 0.0")
            O.out(sql)
            try:
                conn.execute(sql)
                conn.commit()
                O.out(f'New column {nc} added')
            except Exception as e:
                O.out(SEPARATOR)
                O.out(f'Error: {e}')
                O.out(SEPARATOR)
                O.out(f'Rolling back...')
                conn.rollback()

            sql = text(f"""
    UPDATE {table_name} SET {nc}={c}
    """)
            O.out(SEPARATOR)
            O.out(sql)
            conn.execute(sql)
            conn.commit()
            O.out(f'Copied {c} -> {nc}')

            sql = text(f"""
    WITH D AS (
        SELECT pu_location_id as pu, do_location_id as do, AVG({c}) as avg
        FROM {table_name} WHERE {c} > 0 GROUP BY pu_location_id, do_location_id
    ) 
    UPDATE {table_name} SET {nc}=D.avg FROM D 
    WHERE pu_location_id=D.pu 
        AND do_location_id=D.do 
        AND {nc}=0
        """)
            O.out(sql)
            conn.execute(sql)
            conn.commit()
            O.out(f'Fixed missing {c}')

# O.clear()    
# fixLocationBased(2023, GREEN)

In [None]:
def fixTotalAmount(year, taxi_type):
    dr.setTable(year, taxi_type)
    table_name = dr.getTableName()
    with dr.engn.connect() as conn:    
        # fix total_amount by fare amount and other columns
        O.out(SEPARATOR)
        O.out(f'Fix total_amount by fare_amount and other related columns')
        # 'fare_amount', 'mta_tax', 'tip_amount', 'tolls_amount', 
        # 'congestion_surcharge', 'improvement_surcharge', 'extra',

        # F_TOTAL_AMOUNT
        sql = text(f"""
    ALTER TABLE {table_name} ADD COLUMN f_total_amount FLOAT DEFAULT 0.0
    """)
        try:
            O.out(SEPARATOR)
            O.out(sql)
            conn.execute(sql)
            conn.commit()
            O.out('add new column fta')
        except Exception as e:
            O.out(SEPARATOR)
            O.out(f'Error: {e}')
            O.out(SEPARATOR)
            O.out(f'Rolling back...')
            conn.rollback()

        sql = text(f"""
    UPDATE {table_name} SET f_total_amount=total_amount
    """)
        O.out(SEPARATOR)
        O.out(sql)
        conn.execute(sql)
        conn.commit()
        O.out('copy ta -> fta')

        fix_td_sql = text(f"""
    UPDATE {table_name} SET f_total_amount = (f_fare_amount + f_mta_tax + tip_amount + 
                                            tolls_amount + congestion_surcharge + 
                                            improvement_surcharge + extra)
    WHERE f_total_amount=0
    """)
        O.out(sql)
        conn.execute(fix_td_sql)
        conn.commit()
        O.out('fix missing ffa')

# O.clear()
# fixTotalAmount(2023, GREEN)

In [None]:
def fixPassengerCount(year, taxi_type):
    # fix passenger_count - set passenger count = 1 where it is 0 and...
    # trip_distance != 0 and total_amount != 0
    dr.setTable(year, taxi_type)
    table_name = dr.getTableName()

    O.out(SEPARATOR)
    O.out('Fix passenger_count')
    O.out(SEPARATOR)
    with dr.engn.connect() as conn:
        sql = text(f"""
    ALTER TABLE {table_name} ADD COLUMN f_passenger_count INT DEFAULT 0
    """)
        try:
            O.out(SEPARATOR)
            O.out(sql)
            conn.execute(sql)
            conn.commit()
            O.out('add new column fpc')
        except Exception as e:
            O.out(SEPARATOR)
            O.out(f'Error: {e}')
            O.out(SEPARATOR)
            O.out(f'Rolling back...')
            conn.rollback()

        sql = text(f"""
    UPDATE {table_name} SET f_passenger_count=passenger_count
    """)
        O.out(sql)
        conn.execute(sql)
        conn.commit()
        O.out('copy pc -> fpc')

        sql = text(f"""
    UPDATE {table_name} SET f_passenger_count=1
    WHERE passenger_count=0 
        AND f_trip_distance>0 
        AND f_total_amount>0
    """)
        O.out(sql)
        conn.execute(sql)
        conn.commit()
        O.out('fix missing ftd')

# O.clear()
# fixPassengerCount(2023, GREEN)

In [19]:
for t in TABLES:
    taxi_type, year  = t
    runBasicStatistics(year, taxi_type)
    runOtherStatistics(year, taxi_type)
    fixLocationBased(year, taxi_type)
    fixTotalAmount(year, taxi_type)
    fixPassengerCount(year, taxi_type)

OperationalError: (sqlite3.OperationalError) unrecognized token: "2020_taxi_tripsgreen"
[SQL: SELECT * FROM 2020_taxi_tripsgreen]
(Background on this error at: https://sqlalche.me/e/20/e3q8)