In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sqlalchemy.sql import text
from lib import Output, TaxiDBReader, GREEN, YELLOW, TABLES, tabulate, CHUNK_SIZE, TABLE_FORMAT, SEPARATOR

# Text sink for this notebook's report output (Output is provided by lib).
O = Output('output/chart_output.txt')

# Reset the sink, then write the section banner between two separator lines.
# NOTE(review): clear() presumably empties the output file — confirm in lib.Output.
O.clear()
O.out(SEPARATOR)
O.out('Preliminary Viz (charts)')
O.out(SEPARATOR)

# Reader object; later cells use it to select a table and open connections.
dr = TaxiDBReader()

# Example data (replace these with your actual data)
# distance = [1.2, 3.5, 7.8, 2.5, 4.6, 5.9, 6.1, 7.3, 2.2, 3.1]  # Distance data
# passenger_count = [1, 2, 2, 3, 1, 5, 1, 2, 1, 4]  # Passenger count data
# cost = [10.5, 15.7, 22.1, 12.8, 18.3, 24.0, 19.5, 21.2, 13.1, 16.8]  # Cost data


In [None]:
# Select the table for year 2023 / taxi type GREEN on the shared reader, then
# keep its name so it can be interpolated into the SQL text below.
dr.setTable(year=2023, taxi_type=GREEN)
table_name = dr.getTableName()

def getDF(sql):
    """Run ``sql`` over a connection from ``dr.engn`` and return a DataFrame.

    The connection is obtained inside a ``with`` block, so its context
    manager is exited as soon as the result has been read.
    """
    with dr.engn.connect() as db_conn:
        result = pd.read_sql(sql, db_conn)
    return result

# Query the per-trip columns plotted in the cells below, plus a derived
# trip_duration (dropoff minus pickup via unixepoch — presumably seconds).
# Rows are kept only when: 0 < distance < 20, passengers > 0,
# 0 < total amount <= 100, fare > 0 and trip_duration <= 7200.
# table_name is interpolated into the text because it comes from
# dr.getTableName() above rather than from a bound parameter.
# NOTE(review): trip_duration has no lower bound, so rows whose dropoff is at
# or before pickup (zero/negative duration) pass the filter — confirm intended.
# NOTE(review): filtering on the SELECT alias in WHERE and calling unixepoch()
# both look SQLite-specific — confirm the target engine.
sql = text(f'''
SELECT pu_location_id, do_location_id, f_trip_distance, f_passenger_count, f_total_amount, f_fare_amount, 
    (unixepoch(dropoff_datetime)-unixepoch(pickup_datetime)) as trip_duration
           FROM {table_name}
WHERE f_trip_distance > 0 AND f_trip_distance < 20 
    AND f_passenger_count > 0 
    AND f_total_amount > 0 AND f_total_amount <= 100
    AND f_fare_amount > 0
    AND trip_duration <= 7200
''')

# Materialise the filtered trips once; every plotting cell reads from df.
df = getDF(sql)

In [None]:
# histograms

def get_cmap(n, name='hsv'):
    '''Return a colormap with ``n`` discrete entries, so that ``cmap(i)`` for
    each index ``i`` in 0, 1, ..., n-1 gives a distinct color.

    Args:
        n: number of entries the colormap is resampled to.
        name: a standard matplotlib colormap name (default ``'hsv'``).

    Returns:
        A matplotlib colormap object; call it with an integer index to get
        the color for that slot.
    '''
    # Use the pyplot-level lookup: the older plt.cm.get_cmap spelling was
    # deprecated in Matplotlib 3.7 and removed in 3.9, while plt.get_cmap
    # accepts the same (name, lut) arguments on old and new releases.
    return plt.get_cmap(name, n)

columns = ['f_trip_distance', 'f_passenger_count', 'f_total_amount', 'f_fare_amount']
# One more colour slot than there are columns is requested from the map.
cmap = get_cmap(len(columns) + 1)

# Draw one large 10-bin histogram per column, then print that column's range.
for idx, col in enumerate(columns):
    series = df[col]
    lo, hi = series.min(), series.max()

    fig, ax = plt.subplots(figsize=(16, 9))
    ax.hist(series, bins=10, range=(lo, hi), color=cmap(idx))
    ax.set_xlabel(col)
    ax.set_ylabel('Frequency')
    plt.show()

    print(f'column: {col}, min: {lo}, max: {hi}')

In [None]:
# scatter plots

# (x column, y column) pairs to plot against each other.
scatter_pairs = [
    ('trip_duration', 'f_total_amount'),
    ('f_trip_distance', 'f_total_amount'),
]
# Build a colour map for this cell instead of reusing whatever `cmap` another
# cell last left in the kernel, so the colours do not depend on run order.
scatter_cmap = get_cmap(len(scatter_pairs) + 1)

# One full-size scatter plot per pair, axes labelled with the column names.
for i, (xc, yc) in enumerate(scatter_pairs):
    fig, ax = plt.subplots(figsize=(16, 9))
    ax.scatter(df[xc], df[yc], color=scatter_cmap(i), alpha=0.7)
    ax.set_xlabel(xc)
    ax.set_ylabel(yc)
    plt.show()


In [None]:
# box plots 
columns = ['f_trip_distance', 'f_passenger_count', 'f_total_amount', 'f_fare_amount', 'trip_duration']
cmap = get_cmap(len(columns) + 1)

box_plot_data = []  # created here but never appended to in this cell
# One narrow box plot per column, each box filled with that column's colour.
for idx, col in enumerate(columns):
    fig, ax = plt.subplots(figsize=(3, 6))
    # patch_artist=True makes the boxes fillable patches.
    # NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels`
    # — confirm against the version this notebook is pinned to.
    box_plot = ax.boxplot(df[col], patch_artist=True, labels=[col])

    fill = cmap(idx)
    for box in box_plot['boxes']:
        box.set_facecolor(fill)

    ax.grid()
    plt.show()
