In [1]:
import pandas as pd
import requests as rq
import os

In [2]:
from dataclasses import dataclass

@dataclass
class CliParams:
    """Run parameters for this notebook, shaped like command-line options.

    The defaults select the January 2021 file and cap each CSV read at
    100,000 rows.
    """
    year: int = 2021  # year of the trip-data file to fetch
    month: int = 1  # month of the trip-data file; zero-padded to two digits when the file name is built
    cache_dir: str = "cache"  # directory where downloaded files are kept
    # Passed to pandas.read_csv as `nrows`, so it limits how many rows are loaded.
    # The field name is misspelled ("chunck"), but the cells below read it under this name.
    chunck_size: int | None = 100_000

# Shared instance read by the cells below.
params = CliParams()

In [3]:
# Base URL of the release that hosts the yellow-taxi CSV archives.
URL_PREFIX = "https://github.com/DataTalksClub/nyc-tlc-data/releases/download/yellow/"
# File name for the configured year/month (month zero-padded to two digits).
# Despite the variable name this is a bare file name: it is appended to
# URL_PREFIX for the download and to the cache directory for local reads.
file_path = f"yellow_tripdata_{params.year}-{params.month:02d}.csv.gz"


In [4]:
async def is_file_cached(file_name: str) -> bool:
    '''Report whether `file_name` is already present in the cache directory.

    The function stays `async` so existing callers can keep awaiting it; the
    body itself performs one synchronous filesystem check.

    Args:
        file_name: bare name of the file to look for inside `params.cache_dir`.

    Returns:
        True when a regular file with that name exists in the cache directory,
        False otherwise (including when the cache directory does not exist).
    '''
    # A single isfile() check replaces the exists() + listdir() scan: it does
    # not list the whole directory, does not raise when the cache path is
    # missing or is not a directory, and does not mistake a sub-directory
    # with the same name for a cached file.
    return os.path.isfile(os.path.join(params.cache_dir, file_name))

In [5]:
async def download_file(file_name: str, timeout: float = 30.0) -> None:
    '''Download `file_name` from URL_PREFIX into the cache directory.

    Nothing is fetched when `is_file_cached` reports the file as present.
    A non-2xx response is reported with a message and the download is
    abandoned without raising, so the caller sees no exception in that case.

    The HTTP call is made with the synchronous `requests` client, so it
    blocks while the download runs even though the function is `async`.

    Args:
        file_name: bare file name; appended to URL_PREFIX to form the URL and
            used as the name of the file written under `params.cache_dir`.
        timeout: seconds passed to `requests` as the request timeout, so a
            stalled server raises instead of hanging indefinitely.

    Returns:
        None.
    '''
    if await is_file_cached(file_name):
        print(f"found file: {file_name} in cache. skipping download")
        return

    target_path = os.path.join(params.cache_dir, file_name)
    # Data is written under a temporary name first, so an interrupted download
    # never leaves a truncated file that a later cache check would accept.
    partial_path = target_path + ".part"

    # stream=True avoids holding the whole archive in memory at once.
    with rq.get(URL_PREFIX + file_name, stream=True, timeout=timeout) as resp:
        if not 200 <= resp.status_code < 300:
            print(f"Status code {resp.status_code} received, expected a 2xx code. aborting download.")
            return

        os.makedirs(params.cache_dir, exist_ok=True)
        try:
            with open(partial_path, 'wb') as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
            # Publish the finished file under its real name in one step.
            os.replace(partial_path, target_path)
        except BaseException:
            # Remove the incomplete file, then let the original error propagate.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise


In [6]:
# Top-level `await` is valid here because the notebook kernel runs cells inside an event loop.
await download_file(file_name=file_path)

found file: yellow_tripdata_2021-01.csv.gz in cache. skipping download


In [7]:
# First pass: read without dtype hints to see what pandas infers.
# `nrows` caps how many rows are loaded from the cached archive.
df = pd.read_csv(params.cache_dir + '/' + file_path, nrows=params.chunck_size)
df.head()

Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge
0,1,2021-01-01 00:30:10,2021-01-01 00:36:12,1,2.1,1,N,142,43,2,8.0,3.0,0.5,0.0,0.0,0.3,11.8,2.5
1,1,2021-01-01 00:51:20,2021-01-01 00:52:19,1,0.2,1,N,238,151,2,3.0,0.5,0.5,0.0,0.0,0.3,4.3,0.0
2,1,2021-01-01 00:43:30,2021-01-01 01:11:06,1,14.7,1,N,132,165,1,42.0,0.5,0.5,8.65,0.0,0.3,51.95,0.0
3,1,2021-01-01 00:15:48,2021-01-01 00:31:01,0,10.6,1,N,138,132,1,29.0,0.5,0.5,6.05,0.0,0.3,36.35,0.0
4,2,2021-01-01 00:31:49,2021-01-01 00:48:21,1,4.94,1,N,68,33,1,16.5,0.5,0.5,4.06,0.0,0.3,24.36,2.5


In [8]:
# Column types pandas inferred on its own (no dtype hints were given).
df.dtypes

VendorID                   int64
tpep_pickup_datetime      object
tpep_dropoff_datetime     object
passenger_count            int64
trip_distance            float64
RatecodeID                 int64
store_and_fwd_flag        object
PULocationID               int64
DOLocationID               int64
payment_type               int64
fare_amount              float64
extra                    float64
mta_tax                  float64
tip_amount               float64
tolls_amount             float64
improvement_surcharge    float64
total_amount             float64
congestion_surcharge     float64
dtype: object

In [9]:
# (rows, columns) of the loaded sample.
df.shape

(100000, 18)

In [10]:
# Explicit column types for the second read. The capitalised "Int64" is
# pandas' nullable integer type, so these columns can hold missing values.
dtype = {
    "VendorID": "Int64",
    "passenger_count": "Int64",
    "trip_distance": "float64",
    "RatecodeID": "Int64",
    "store_and_fwd_flag": "string",
    "PULocationID": "Int64",
    "DOLocationID": "Int64",
    "payment_type": "Int64",
    "fare_amount": "float64",
    "extra": "float64",
    "mta_tax": "float64",
    "tip_amount": "float64",
    "tolls_amount": "float64",
    "improvement_surcharge": "float64",
    "total_amount": "float64",
    "congestion_surcharge": "float64"
}

# Columns to parse as datetimes; the untyped read left them as `object`.
parse_dates = [
    "tpep_pickup_datetime",
    "tpep_dropoff_datetime"
]

# Re-read the same cached file with explicit types. This replaces the `df`
# produced by the untyped read.
df = pd.read_csv(
    params.cache_dir + '/' + file_path,
    nrows=params.chunck_size,
    dtype=dtype,
    parse_dates=parse_dates
)

In [11]:
# Preview the first rows of the typed frame.
df.head()

Unnamed: 0,VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge
0,1,2021-01-01 00:30:10,2021-01-01 00:36:12,1,2.1,1,N,142,43,2,8.0,3.0,0.5,0.0,0.0,0.3,11.8,2.5
1,1,2021-01-01 00:51:20,2021-01-01 00:52:19,1,0.2,1,N,238,151,2,3.0,0.5,0.5,0.0,0.0,0.3,4.3,0.0
2,1,2021-01-01 00:43:30,2021-01-01 01:11:06,1,14.7,1,N,132,165,1,42.0,0.5,0.5,8.65,0.0,0.3,51.95,0.0
3,1,2021-01-01 00:15:48,2021-01-01 00:31:01,0,10.6,1,N,138,132,1,29.0,0.5,0.5,6.05,0.0,0.3,36.35,0.0
4,2,2021-01-01 00:31:49,2021-01-01 00:48:21,1,4.94,1,N,68,33,1,16.5,0.5,0.5,4.06,0.0,0.3,24.36,2.5


In [12]:
# Confirm the explicit dtypes and the datetime parsing were applied.
df.dtypes

VendorID                          Int64
tpep_pickup_datetime     datetime64[ns]
tpep_dropoff_datetime    datetime64[ns]
passenger_count                   Int64
trip_distance                   float64
RatecodeID                        Int64
store_and_fwd_flag       string[python]
PULocationID                      Int64
DOLocationID                      Int64
payment_type                      Int64
fare_amount                     float64
extra                           float64
mta_tax                         float64
tip_amount                      float64
tolls_amount                    float64
improvement_surcharge           float64
total_amount                    float64
congestion_surcharge            float64
dtype: object

In [13]:
# (rows, columns) of the typed frame.
df.shape

(100000, 18)

In [14]:
from cassandra.cluster import Cluster, ResultSet

# Connection settings for the Cassandra node this notebook talks to
# (a local node; 9042 is Cassandra's default native-protocol port).
CASSANDRA_HOST = '127.0.0.1'
CASSANDRA_PORT = 9042
CASSANDRA_KEYSPACE = 'dev'  # keyspace this notebook creates, uses and finally drops

cluster = Cluster(
    [CASSANDRA_HOST], 
    port=CASSANDRA_PORT, 
)
# Connect without selecting a keyspace; one is chosen later via set_keyspace.
session = cluster.connect()

In [15]:
# Create the keyspace if it is missing, using SimpleStrategy with a replication factor of 1.
# Only the first literal is an f-string, so the braces in the second literal are sent as written.
session.execute(f"CREATE KEYSPACE IF NOT EXISTS {CASSANDRA_KEYSPACE} WITH replication = "
                "{'class': 'SimpleStrategy', 'replication_factor': 1};"
)

<cassandra.cluster.ResultSet at 0x7fa97401cec0>

In [16]:
# Name of the table that will hold the trip rows.
TABLE_NAME = 'yellow'

# Make the keyspace the session default so later statements can name the table without a keyspace prefix.
# The trailing semicolon suppresses the cell's output.
session.set_keyspace(CASSANDRA_KEYSPACE);

In [17]:
from pandas.io import sql

TABLE_NAME = 'yellow'
# Let pandas draft a CREATE TABLE statement from the frame's columns and dtypes.
pandas_gen_table: str = sql.get_schema(df, name=TABLE_NAME)

# Translate the drafted type names into CQL type names.
# NOTE(review): CQL FLOAT and INT are 32-bit, while the frame holds
# float64/Int64 data; DOUBLE/BIGINT would avoid narrowing. Confirm before
# changing the schema.
cql_schema = pandas_gen_table.replace('REAL', 'FLOAT').replace('INTEGER', 'INT')

# Insert the primary-key column before the statement's final ')' only.
# A blanket replace(')') would also rewrite any parenthesis that appears in a
# column name or type.
schema_head, closing_paren, schema_tail = cql_schema.rpartition(')')
if not closing_paren:
    raise ValueError("generated schema has no closing parenthesis to extend")
table_definition = f"{schema_head}, id UUID PRIMARY KEY\n{closing_paren}{schema_tail}"
print(table_definition)

CREATE TABLE "yellow" (
"VendorID" INT,
  "tpep_pickup_datetime" TIMESTAMP,
  "tpep_dropoff_datetime" TIMESTAMP,
  "passenger_count" INT,
  "trip_distance" FLOAT,
  "RatecodeID" INT,
  "store_and_fwd_flag" TEXT,
  "PULocationID" INT,
  "DOLocationID" INT,
  "payment_type" INT,
  "fare_amount" FLOAT,
  "extra" FLOAT,
  "mta_tax" FLOAT,
  "tip_amount" FLOAT,
  "tolls_amount" FLOAT,
  "improvement_surcharge" FLOAT,
  "total_amount" FLOAT,
  "congestion_surcharge" FLOAT
, id UUID PRIMARY KEY
)


In [18]:
# Drop any table left over from an earlier run so this cell can be re-executed,
# then create it from the generated definition.
# The name is double-quoted here because the CREATE TABLE statement quotes it;
# an unquoted name would only match while TABLE_NAME is all lowercase.
session.execute(f'DROP TABLE IF EXISTS "{TABLE_NAME}"')
session.execute(table_definition)

<cassandra.cluster.ResultSet at 0x7fa973e13bb0>

Now let's prepare the insert statement.

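Note: the column names must be wrapped in double quotes. CQL treats unquoted identifiers as case-insensitive, so mixed-case names such as `VendorID` would otherwise not match the columns created above.
(A better approach may be to normalize all column names to lowercase from the beginning.)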

In [19]:
# The pieces are built outside the f-string: reusing the same quote character
# or a backslash inside an f-string replacement field is a syntax error before
# Python 3.12.

# Every column name is double-quoted so mixed-case names keep their case.
quoted_columns = '\n, '.join(f'"{column}"' for column in df.columns)
# One bind marker per DataFrame column; the id comes from the CQL uuid() function instead.
bind_markers = ', '.join(['?'] * df.shape[1])

query_parameterized = (
    f'INSERT INTO {TABLE_NAME} (\n{quoted_columns}\n, "id"\n) VALUES '
    f'(\n{bind_markers}, uuid()\n)'
)
print(query_parameterized)

INSERT INTO yellow (
"VendorID"
, "tpep_pickup_datetime"
, "tpep_dropoff_datetime"
, "passenger_count"
, "trip_distance"
, "RatecodeID"
, "store_and_fwd_flag"
, "PULocationID"
, "DOLocationID"
, "payment_type"
, "fare_amount"
, "extra"
, "mta_tax"
, "tip_amount"
, "tolls_amount"
, "improvement_surcharge"
, "total_amount"
, "congestion_surcharge"
, "id"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, uuid()
)


In [20]:
# Prepare the INSERT once so the same statement can be reused for every row.
# (The variable name is a misspelling of "prepared"; the insert cell refers to it by this name.)
perpared = session.prepare(query_parameterized)

One important note about Cassandra: inserting a row whose primary key already exists behaves like an UPSERT (the existing row is overwritten) rather than failing.

In [21]:
from tqdm import tqdm

# Lazily pair the prepared statement with each row's values and hand the pairs to the driver.
# NOTE(review): results_generator=True presumably yields results as they arrive instead of collecting them first; confirm against the driver docs.
results = session.execute_concurrent(((perpared, r.values) for _, r in df.iterrows()), results_generator=True)

# Walk the results (tqdm shows progress) and print any that did not succeed.
for res in tqdm(results):
    if not res.success:
        print(res)

100000it [00:37, 2692.73it/s]


In [22]:
# Sanity check: count the rows now in the table. Each insert gets a fresh
# uuid() key, so the count is expected to equal the number of rows in df.
result: ResultSet = session.execute(f'SELECT COUNT(*) FROM "{TABLE_NAME}"')
result.all()

[Row(count=100000)]

In [23]:
# Clean up: remove the keyspace used by this notebook, and the table in it.
# NOTE(review): the keyspace was created with IF NOT EXISTS, so if it already
# existed before this run, this statement drops that pre-existing keyspace too.
session.execute(f'DROP KEYSPACE {CASSANDRA_KEYSPACE}')
session.shutdown()
# Shutting down only the session leaves the Cluster object open; close it as
# well so its connections are released.
cluster.shutdown()