In [2]:
# Imports and constants
import os
import time
import traceback
import warnings

import numpy as np
import pandas as pd
from sqlalchemy import MetaData, Table, Column, Integer, DateTime, Float
from sqlalchemy import create_engine, select, insert, event
from sqlalchemy.engine import URL
from sqlalchemy.sql import text
import warnings

warnings.filterwarnings("ignore")


# Raw CSV header -> normalized SQL column name.
# Green files use lpep_* datetime headers and yellow files use tpep_*; both map onto
# the same target column. TaxiDB.prepareData() drops any CSV column not listed here.
_RENAMED_COLUMNS = {
    'lpep_pickup_datetime': 'pickup_datetime',
    'tpep_pickup_datetime': 'pickup_datetime',
    'lpep_dropoff_datetime': 'dropoff_datetime',
    'tpep_dropoff_datetime': 'dropoff_datetime',
    'RatecodeID': 'ratecode_id',
    'PULocationID': 'pu_location_id',
    'DOLocationID': 'do_location_id',
}
# Columns whose CSV header is already the SQL column name.
_UNCHANGED_COLUMNS = (
    'passenger_count', 'trip_distance', 'fare_amount', 'extra', 'mta_tax',
    'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount',
    'payment_type', 'congestion_surcharge',
)
normalizedColumns = dict(_RENAMED_COLUMNS)
normalizedColumns.update((name, name) for name in _UNCHANGED_COLUMNS)

# Functions
def getODBCString(server=None, database=None, username=None, password=None):
    """Build the SQLAlchemy URL for the Azure SQL database (pyodbc, ODBC Driver 18).

    Each argument falls back to an environment variable and then, except for the
    password, to the project default:

    - server:   NYC_TAXI_DB_SERVER   (default 'tcp:nyc-taxi-2024.database.windows.net,1433')
    - database: NYC_TAXI_DB_NAME     (default 'nyc_taxi_2024')
    - username: NYC_TAXI_DB_USER     (default 'ishmakwana')
    - password: NYC_TAXI_DB_PASSWORD (no default: credentials must not live in the notebook)

    Returns:
        sqlalchemy.engine.URL suitable for ``create_engine``.

    Raises:
        RuntimeError: if no password is passed and NYC_TAXI_DB_PASSWORD is unset.
    """
    server = server or os.environ.get('NYC_TAXI_DB_SERVER', 'tcp:nyc-taxi-2024.database.windows.net,1433')
    database = database or os.environ.get('NYC_TAXI_DB_NAME', 'nyc_taxi_2024')
    username = username or os.environ.get('NYC_TAXI_DB_USER', 'ishmakwana')
    password = password or os.environ.get('NYC_TAXI_DB_PASSWORD')
    if not password:
        raise RuntimeError('no database password: pass password=... or set NYC_TAXI_DB_PASSWORD')

    def quote(value):
        # ODBC attribute values containing ';', '{' or '}' must be wrapped in braces,
        # with any '}' doubled; otherwise e.g. a password 'a;b' would split the string.
        return '{' + str(value).replace('}', '}}') + '}'

    con_str = (
        'DRIVER={ODBC Driver 18 for SQL Server};'
        f'SERVER={server};DATABASE={database};UID={quote(username)};PWD={quote(password)};'
        'Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;'
    )
    return URL.create("mssql+pyodbc", query={"odbc_connect": con_str})

def getSQLiteString():
    """Return the SQLAlchemy URL of the local SQLite database (path relative to the working directory)."""
    url = 'sqlite:///db/taxi_db.db'
    return url

def getDateColumns(isGreen):
    """Return the raw CSV's [pickup, dropoff] datetime column names.

    Green-cab CSVs use ``lpep_*`` headers, yellow-cab CSVs use ``tpep_*``.
    Only a value that compares equal to True selects the green names.
    """
    if isGreen == True:  # noqa: E712 -- deliberate: anything not equal to True means yellow
        pickup, dropoff = 'lpep_pickup_datetime', 'lpep_dropoff_datetime'
    else:
        pickup, dropoff = 'tpep_pickup_datetime', 'tpep_dropoff_datetime'
    return [pickup, dropoff]


class TaxiDB:
    """Loads one NYC taxi trip CSV into its own per-year, per-colour SQL table.

    The target table is named by getTableName() (e.g. ``green_taxi_trips2023``).
    The database schema is reflected on construction; the target table must already
    exist unless ``create=True`` is passed.
    """

    # Normalized columns that prepareData() casts to int after filling missing values with 0.
    INT_COLUMNS = ('ratecode_id', 'pu_location_id', 'do_location_id', 'passenger_count', 'payment_type')

    def __init__(self, src, yr, isGreen = True, create = False):
        """Connect to the database and bind to the table for (yr, isGreen).

        Args:
            src: path of the CSV that uploadTable() reads.
            yr: year of the data; only used to build the table name.
            isGreen: True for green-cab data (lpep_* columns), otherwise yellow (tpep_*).
            create: if True and the table is missing, create it with createTable()
                instead of raising.

        Raises:
            KeyError: if the table does not exist and ``create`` is False.
        """
        self.md = MetaData()
        self.yr = yr
        self.g = isGreen
        self.src = src

        # To target Azure SQL instead, use create_engine(getODBCString()). With pyodbc, a
        # 'before_cursor_execute' listener that sets cursor.fast_executemany = True
        # speeds up the bulk inserts done by uploadTable().
        self.engn = create_engine(getSQLiteString())
        # reflect() opens a connection, so it doubles as the connectivity check.
        self.md.reflect(self.engn)
        print('sql engine ready')

        name = self.getTableName()
        if name not in self.md.tables:
            if not create:
                raise KeyError(f'table {name} not found in the database; pass create=True to create it')
            self.createTable()
            print(f'table created: {name}')
        self.table = self.md.tables[name]

    def describeTables(self):
        """Print the name and row count of every reflected table."""
        with self.engn.connect() as conn:
            for name, table in self.md.tables.items():
                count = conn.execute(text(f'select COUNT(1) from {table}')).scalar()
                print(f'table name: {name}, #rows: {count}')

    def printRowCount(self):
        """Print the number of rows in this instance's table."""
        with self.engn.connect() as conn:
            result = conn.execute(text(f'select COUNT(1) from {self.getTableName()}'))
            print(result.scalar())

    def prepareData(self):
        """Normalize the chunk held in ``self.df`` for insertion, replacing ``self.df``.

        Keeps only the columns listed in ``normalizedColumns`` (renamed, in CSV order),
        fills missing values with 0 and casts INT_COLUMNS to int.
        """
        keep = [normalizedColumns[c] for c in self.df.columns if c in normalizedColumns]
        # Build a new frame rather than mutating a column slice in place; the old
        # inplace fillna/assignment on a slice is a SettingWithCopy pitfall.
        # NOTE(review): fillna(0) also hits the datetime columns, so a missing
        # pickup/dropoff time becomes 0 rather than NULL -- confirm this is intended.
        normalized = self.df.rename(columns=normalizedColumns)[keep].fillna(0)
        self.df = normalized.astype({c: int for c in self.INT_COLUMNS})

        print(f'columns: {list(self.df.columns)}, #rows: {len(self.df)}')

    def getTableName(self):
        """Return the table name for this instance, e.g. 'green_taxi_trips2023'."""
        colr = 'green' if self.g == True else 'yellow'
        return f'{colr}_taxi_trips{self.yr}'

    def createTable(self):
        """Create this instance's trip table if it does not exist, and return it."""
        table = Table(
            self.getTableName(), self.md,
            Column('id', Integer, primary_key = True),
            Column('pickup_datetime', DateTime),
            Column('dropoff_datetime', DateTime),
            Column('ratecode_id', Integer),
            Column('pu_location_id', Integer),
            Column('do_location_id', Integer),
            Column('passenger_count', Integer),
            Column('trip_distance', Float),
            Column('fare_amount', Float),
            Column('extra', Float),
            Column('mta_tax', Float),
            Column('tip_amount', Float),
            Column('tolls_amount', Float),
            Column('improvement_surcharge', Float),
            Column('total_amount', Float),
            Column('payment_type', Integer),
            Column('congestion_surcharge', Float),
            # Reuse an already-reflected definition instead of raising on redefinition.
            keep_existing=True
        )

        self.md.create_all(self.engn)
        return table

    def dropTableByName(self, name):
        """Drop the reflected table called ``name``; do nothing if it is unknown."""
        if name in self.md.tables.keys():
            self.dropTable(self.md.tables[name])

    def dropTable(self, table):
        """Drop ``table`` from the database and remove it from the reflected metadata.

        Failures are printed rather than raised. Note that ``self.table`` may still
        reference the dropped table afterwards.
        """
        try:
            # Table.drop(engine) runs in its own transaction (committed or rolled back
            # by SQLAlchemy), so no separate connection is needed here.
            table.drop(self.engn)
        except Exception:
            print(f'could not drop table: {traceback.format_exc()}')
            return
        # Forget the table so describeTables()/dropTableByName() no longer see it.
        self.md.remove(table)
        print('table dropped')

    def uploadTable(self, chunksize = 2000000, skip_rows = 0):
        """Stream ``self.src`` into ``self.table``, committing each chunk separately.

        A chunk whose insert fails is rolled back and reported; later chunks still load.

        Args:
            chunksize: number of CSV rows read and inserted per transaction.
            skip_rows: resume support -- chunks ending at or before this many rows are
                read but not inserted (e.g. after an interrupted upload).

        Returns:
            int: number of rows actually committed.
        """
        start = time.time()

        processed = 0
        committed = 0
        with pd.read_csv(self.src,
                         chunksize=chunksize,
                         parse_dates=getDateColumns(self.g),
                         date_format="%m/%d/%Y %I:%M:%S %p") as reader:
            for chunk in reader:
                processed += len(chunk)
                if processed <= skip_rows:
                    print(f'{processed} records skipped')
                    continue

                self.df = chunk
                self.prepareData()
                records = self.df.to_dict('records')
                print(f'{processed} records processed')

                with self.engn.connect() as conn:
                    try:
                        conn.execute(insert(self.table), records)
                        conn.commit()
                        # Only count rows once the transaction has really committed.
                        committed += len(records)
                        print('committed')
                    except Exception:
                        print(f'commit failed: {traceback.format_exc()}')
                        conn.rollback()
                        print('rolled back')

        print(f'upload complete, rows committed: {committed}, time taken: {time.time() - start} seconds')
        return committed



In [3]:
# Playground 1

# One (csv path, year, isGreen) entry per source file. Each CSV goes into its own
# table (e.g. green_taxi_trips2023), so year + colour identify where the data lives.
sources = [
    ('data/2023_Yellow_Taxi_Trip_Data.csv', 2023, False),
    ('data/2023_Green_Taxi_Trip_Data.csv', 2023, True),
    ('data/2022_Yellow_Taxi_Trip_Data.csv', 2022, False),
    ('data/2022_Green_Taxi_Trip_Data.csv', 2022, True),
    ('data/2021_Yellow_Taxi_Trip_Data.csv', 2021, False),
    ('data/2021_Green_Taxi_Trip_Data.csv', 2021, True),
    ('data/2020_Yellow_Taxi_Trip_Data.csv', 2020, False),
    ('data/2020_Green_Taxi_Trip_Data.csv', 2020, True),
]

# The tables are already loaded: re-running the upload would duplicate every row
# (and takes a long time). Set to True only when loading into empty tables.
UPLOAD_TABLES = False

if UPLOAD_TABLES:
    for src, year, green in sources:
        TaxiDB(src, year, green).uploadTable()

# Connect via the first source's table and print every table with its row count.
src, year, green = sources[0]
mydb = TaxiDB(src, year, green)
mydb.describeTables()

# To drop the current table (re-uploading is slow): mydb.dropTable(mydb.table)

sql engine ready
table name: green_taxi_trips2020, #rows: 1675896
table name: green_taxi_trips2021, #rows: 1011529
table name: green_taxi_trips2022, #rows: 808818
table name: green_taxi_trips2023, #rows: 761980
table name: taxi_zones, #rows: 263
table name: yellow_taxi_trips2020, #rows: 24014901
table name: yellow_taxi_trips2021, #rows: 29905605
table name: yellow_taxi_trips2022, #rows: 38133567
table name: yellow_taxi_trips2023, #rows: 36565008


In [None]:
# Preview the taxi_zones lookup table. It is created by the "create taxi zone table"
# cell below, so on a fresh database run that cell first.
engine = create_engine(getSQLiteString())

with engine.connect() as conn:
    df = pd.read_sql(text('SELECT * FROM taxi_zones'), conn)

df.head()

In [13]:
# create taxi zone table
source = 'data/taxi_zones.csv'

# Raw CSV column -> taxi_zones column. The keys also fix which columns are kept, in order.
zone_column_map = {
    'Shape_Leng': 'zone_length',
    'the_geom': 'zone_shape',
    'Shape_Area': 'zone_area',
    'zone': 'zone',
    'LocationID': 'location_id',
    'borough': 'location_name',
}
og_zone_columns = list(zone_column_map)

df = pd.read_csv(source)[og_zone_columns].rename(columns=zone_column_map)
df.info()

engine = create_engine(getSQLiteString())

# engine.begin() commits on success and rolls back on error.
# index=False: don't persist the pandas RangeIndex as an extra 'index' column.
with engine.begin() as conn:
    df.to_sql('taxi_zones', conn, if_exists='replace', index=False)

            

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 263 entries, 0 to 262
Data columns (total 6 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   zone_length    263 non-null    float64
 1   zone_shape     263 non-null    object 
 2   zone_area      263 non-null    float64
 3   zone           263 non-null    object 
 4   location_id    263 non-null    int64  
 5   location_name  263 non-null    object 
dtypes: float64(2), int64(1), object(3)
memory usage: 12.5+ KB
