In [28]:
# Imports and constants
import os
from getpass import getpass

import numpy as np
import pandas as pd
from sqlalchemy import MetaData, Table, Column, Integer, DateTime, Float
from sqlalchemy import create_engine, select, insert, func
from sqlalchemy.engine import URL

# src = "data/2020_Green_Taxi_Trip_Data.csv"
# src = "data/2020_Yellow_Taxi_Trip_Data.csv"
# df = pd.read_csv(src, parse_dates=['lpep_pickup_datetime','lpep_dropoff_datetime'], date_format="%m/%d/%Y %I:%M:%S %p")

# Maps raw CSV column names to the column names of the SQL trips table.
normalizedColumns = {
    # Timestamps: green-taxi CSVs use an "lpep_" prefix, yellow-taxi CSVs "tpep_";
    # both land in the same table column.
    **{
        f'{prefix}_{event}_datetime': f'{event}_datetime'
        for event in ('pickup', 'dropoff')
        for prefix in ('lpep', 'tpep')
    },
    # CamelCase ID columns become snake_case.
    'RatecodeID': 'ratecode_id',
    'PULocationID': 'pu_location_id',
    'DOLocationID': 'do_location_id',
    # The remaining columns already carry the table's names and map to themselves.
    **{
        name: name
        for name in (
            'passenger_count', 'trip_distance', 'fare_amount', 'extra', 'mta_tax',
            'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount',
            'payment_type', 'congestion_surcharge',
        )
    },
}

In [29]:
# Functions
def getConnectionString(password=None):
    """Build the ODBC connection string for the Azure SQL taxi database.

    The password is never stored in the notebook. It is taken, in order of
    preference, from the ``password`` argument, the ``NYC_TAXI_DB_PASSWORD``
    environment variable, or an interactive prompt.

    Args:
        password: optional password override.

    Returns:
        str: connection string for pyodbc (passed as ``odbc_connect``).
    """
    SERVER = 'tcp:nyc-taxi-2024.database.windows.net,1433'
    DATABASE = 'nyc_taxi_2024'
    USERNAME = 'ishmakwana'

    if password is None:
        password = os.environ.get('NYC_TAXI_DB_PASSWORD') or getpass(f'Password for {USERNAME}: ')

    # ODBC values may contain ';' or '{' only when wrapped in braces, with any
    # literal '}' doubled; bracing unconditionally is always valid.
    quoted_password = '{' + password.replace('}', '}}') + '}'

    return f'DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={SERVER};DATABASE={DATABASE};UID={USERNAME};PWD={quoted_password};Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;'

def getDateColumns(isGreen):
    """Return the pickup and dropoff timestamp column names of a raw trip CSV.

    Green-taxi files prefix them with ``lpep_``; any other value selects the
    yellow-taxi ``tpep_`` prefix.
    """
    prefix = 'lpep' if isGreen == True else 'tpep'  # noqa: E712 - equality check kept on purpose
    return [f'{prefix}_pickup_datetime', f'{prefix}_dropoff_datetime']

def makeValues(source, index, year=2020):
    """Convert one CSV row into a dict of values for the trips table.

    Only columns listed in ``normalizedColumns`` are kept, under their
    normalized names. Missing values in non-timestamp columns become 0;
    missing pickup/dropoff timestamps become None (SQL NULL).

    Args:
        source: DataFrame read from a taxi-trip CSV.
        index: row label to convert (``readSrc`` uses the default RangeIndex,
            so this is also the row position).
        year: value stored in the ``year`` column. Defaults to 2020 for
            backward compatibility; pass the dataset's actual year.

    Returns:
        dict mapping normalized column names to the row's values.
    """
    values = {'year': year}
    for col in source.columns:
        target = normalizedColumns.get(col)
        if target is None:
            continue

        cell = source.at[index, col]
        if target in ('pickup_datetime', 'dropoff_datetime'):
            # NaT is not a bindable DB value; store NULL instead
            values[target] = None if pd.isna(cell) else cell
        elif pd.isna(cell):
            values[target] = 0
        elif col in ('RatecodeID', 'PULocationID', 'DOLocationID'):
            # these may be read as floats when the column contains NaNs
            values[target] = int(cell)
        else:
            values[target] = cell

    return values

def makeRecods(source, start, count):
    """Convert ``count`` consecutive rows, beginning at ``start``, into insert dicts.

    Each row goes through ``makeValues``; the list is in row order. The
    misspelled name is kept because ``TaxiDB.uploadTable`` calls it by it.
    """
    stop = start + count
    return [makeValues(source, row) for row in range(start, stop)]


class TaxiDB:
    """Loads one year of NYC green/yellow taxi trips from a CSV into a SQL table.

    Construction connects to the database, reflects the existing tables and
    creates (if missing) the ``{green|yellow}_taxi_trips{year}`` table.

    Attributes:
        md: SQLAlchemy MetaData with the reflected tables plus the trip table.
        yr: dataset year; part of the table name and written to ``year``.
        g: True for green-taxi data, otherwise yellow.
        src: path of the source CSV.
        engn: SQLAlchemy engine for the database.
        table: the trip Table object.
        df: the source DataFrame, set by ``readSrc``.
    """

    def __init__(self, src, yr, isGreen = True):
        self.md = MetaData()
        self.yr = yr
        self.g = isGreen
        self.src = src

        self.engn = create_engine(URL.create("mssql+pyodbc", query={"odbc_connect": getConnectionString()}))
        # load the definitions of every table that already exists in the database
        self.md.reflect(self.engn)

        print('sql engine ready')

        self.createTable()

        print('table created')

        self.table = self.md.tables[self.getTableName()]

    def describeTables(self):
        """Print the name and row count of every table known to ``self.md``."""
        with self.engn.connect() as conn:
            for name, table in self.md.tables.items():
                # count on the server instead of fetching every row
                count = conn.execute(select(func.count()).select_from(table)).scalar_one()
                print(f'table name: {name}, #rows: {count}')

    def readSrc(self):
        """Read ``self.src`` into ``self.df``, parsing pickup/dropoff as datetimes."""
        self.df = pd.read_csv(
            self.src,
            parse_dates=getDateColumns(self.g),
            date_format="%m/%d/%Y %I:%M:%S %p",
        )
        print('csv reading complete')

    def getTableName(self):
        """Return the trip table name, e.g. ``green_taxi_trips2023``."""
        colr = 'green' if self.g == True else 'yellow'
        return f'{colr}_taxi_trips{self.yr}'

    def createTable(self):
        """Define the trip table in ``self.md`` and create it if it does not exist."""
        Table(
            self.getTableName(), self.md,
            Column('id', Integer, primary_key = True),
            Column('year', Integer),
            Column('pickup_datetime', DateTime),
            Column('dropoff_datetime', DateTime),
            Column('ratecode_id', Integer),
            Column('pu_location_id', Integer),
            Column('do_location_id', Integer),
            Column('passenger_count', Integer),
            Column('trip_distance', Float),
            Column('fare_amount', Float),
            Column('extra', Float),
            Column('mta_tax', Float),
            Column('tip_amount', Float),
            Column('tolls_amount', Float),
            Column('improvement_surcharge', Float),
            Column('total_amount', Float),
            Column('payment_type', Integer),
            Column('congestion_surcharge', Float),
            # the table may already be in self.md from reflect(); redefine it in place
            extend_existing=True
        )

        self.md.create_all(self.engn)

    def dropTableByName(self, name):
        """Drop the table called ``name`` if ``self.md`` knows it; otherwise do nothing."""
        if name in self.md.tables.keys():
            self.dropTable(self.md.tables[name])

    def dropTable(self, table):
        """Drop ``table`` from the database and remove it from ``self.md``.

        Failures are printed rather than raised. Note that ``self.table`` still
        refers to the Table object after its own table is dropped.
        """
        try:
            # Table.drop runs in its own transaction on the engine; no
            # separate connection or pre-read of the rows is needed
            table.drop(self.engn)
        except Exception as e:
            print(f'could not drop table: {e}')
            return
        # keep metadata in sync so describeTables() does not query a dropped table
        self.md.remove(table)
        print('table dropped')

    def uploadTable(self, s, t, chnk = 10000):
        """Insert rows ``s`` .. ``t - 1`` of ``self.df`` into the trip table in chunks.

        Each chunk is committed separately. A failing chunk is rolled back and
        reported, and the upload continues with the next chunk.

        Args:
            s: first row index to upload.
            t: stop row index (exclusive), typically ``len(self.df)``.
            chnk: number of rows per INSERT/commit.
        """
        with self.engn.connect() as conn:
            for start in range(s, t, chnk):
                size = min(chnk, t - start)
                records = makeRecods(self.df, start, size)
                # the record builders fill in a default year; use this dataset's year
                for record in records:
                    record['year'] = self.yr
                try:
                    conn.execute(insert(self.table), records)
                    conn.commit()
                except Exception as e:
                    print(f'commit failed: {e}')
                    conn.rollback()
                else:
                    print(f'{start+size} records processed')

    def printRowCount(self):
        """Print the number of rows currently in the trip table."""
        with self.engn.connect() as conn:
            count = conn.execute(select(func.count()).select_from(self.table)).scalar_one()
        print(count)


In [30]:
# Playground 1: connect and make sure the 2023 green-taxi table exists
src = 'data/2023_Green_Taxi_Trip_Data.csv'
year = 2023
green = True

mydb = TaxiDB(src=src, yr=year, isGreen=green)

# Workflow steps (uncomment to run):
# mydb.readSrc()                           # load the CSV into mydb.df
# mydb.describeTables()                    # row counts before the upload
#
# start, total, chunk = 0, len(mydb.df), 10000
# print(f'total to upload={total}')
# mydb.uploadTable(start, total, chunk)    # insert in separately committed chunks
#
# mydb.dropTable(mydb.table)               # remove the table again
# mydb.describeTables()

sql engine ready
table created
table dropped
table dropped
