In [3]:
!pip install pymongo pyarrow python-dateutil


Collecting pymongo
  Downloading pymongo-4.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (22 kB)
Collecting dnspython<3.0.0,>=1.16.0 (from pymongo)
  Downloading dnspython-2.7.0-py3-none-any.whl.metadata (5.8 kB)
Downloading pymongo-4.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading dnspython-2.7.0-py3-none-any.whl (313 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m313.6/313.6 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: dnspython, pymongo
Successfully installed dnspython-2.7.0 pymongo-4.14.0


In [1]:
import os
from urllib.parse import quote_plus

# Connection settings. Each value can be overridden through the environment so real
# credentials do not have to be written into the notebook; the fallbacks are the
# values that were previously hard-coded here.
USER = os.environ.get("MONGO_USER", "mongouser")
PWD  = os.environ.get("MONGO_PASSWORD", "mongopassword")
HOST = os.environ.get("MONGO_HOST", "mongodb")
DB   = "taxi_logs"   # your target DB for writes

# Try admin as the auth DB (common when the root user was created via MONGO_INITDB_* envs).
# User name and password are percent-encoded because characters such as '@', ':' or '/'
# would otherwise break the URI.
os.environ["MONGO_URL"] = (
    f"mongodb://{quote_plus(USER)}:{quote_plus(PWD)}@{HOST}:27017/{DB}?authSource=admin"
)
# Show the URL with the password masked so the secret is not stored in the cell output.
print(f"mongodb://{quote_plus(USER)}:***@{HOST}:27017/{DB}?authSource=admin")

mongodb://mongouser:mongopassword@mongodb:27017/taxi_logs?authSource=admin


In [4]:
from pymongo import MongoClient
# Connect with the URL stored in the environment and send a "ping" command to the
# admin database as a connectivity check.
client = MongoClient(os.environ["MONGO_URL"])
ping_reply = client.admin.command("ping")
print(ping_reply)  # should print {'ok': 1.0}

{'ok': 1.0}


In [None]:
# parquet_to_mongo.py
import os
from datetime import timezone
from pymongo import MongoClient, InsertOne
import pyarrow.dataset as ds
import pyarrow as pa
import pyarrow.compute as pc
from dateutil import parser as dtparse
import re

# Notebook defaults. setdefault() leaves a value that is already present in the
# environment untouched, so the paths/URL can be overridden without editing this cell.
os.environ.setdefault("PARQUET_ROOT", "/home/jovyan/work/data/nyc-taxi/partitioned/year=2019")

# Authenticated connection string (authSource=admin); used only when MONGO_URL is not set yet.
os.environ.setdefault("MONGO_URL", "mongodb://mongouser:mongopassword@mongodb:27017/taxi_logs?authSource=admin")


MONGO_URL = os.environ["MONGO_URL"]
PARQUET_ROOT = os.environ["PARQUET_ROOT"]  # e.g., mounted host path or HDFS copy to local
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "5000"))  # rows per record batch / bulk insert

client = MongoClient(MONGO_URL)
db = client["taxi_logs"]
col = db["trips"]

# Define a tiny transformer from an Arrow batch to list[dict]
def batch_to_docs(batch: pa.RecordBatch, partition_values: dict[str, str] | None = None):
    # Columns expected in NYC taxi parquet
    cols = {name: batch.column(i) for i, name in enumerate(batch.schema.names)}
    # Some datasets have timestamps as strings; some as timestamp types. Normalize to ISO strings then to datetimes.
    pick_col = cols.get("tpep_pickup_datetime")
    drop_col = cols.get("tpep_dropoff_datetime")
    pu_col   = cols.get("PULocationID")
    do_col   = cols.get("DOLocationID")
    pc_col   = cols.get("passenger_count")
    fare_col = cols.get("fare_amount")

    n = batch.num_rows
    docs = []
    for i in range(n):
        # Extract values safely (handle nulls)
        # New:
        def get_value(col, i):
            # Returns None automatically for nulls
            if col is None:
                return None
            return col[i].as_py()
        
        def to_dt_utc(x):
            # Normalize to timezone-aware UTC
            if x is None:
                return None
            # x can be a Python datetime (from Arrow) or a string
            from datetime import datetime, timezone
            if isinstance(x, str):
                from dateutil import parser as dtparse
                dt = dtparse.parse(x)
            else:
                dt = x
            if getattr(dt, "tzinfo", None) is None:
                dt = dt.replace(tzinfo=timezone.utc)
            else:
                dt = dt.astimezone(timezone.utc)
            return dt

        doc = {
            "pickup": {
                "time": to_dt_utc(get_value(pick_col, i)),
                "location_id": get_value(pu_col, i)
            },
            "dropoff": {
                "time": to_dt_utc(get_value(drop_col, i)),
                "location_id": get_value(do_col, i)
            },
            "passenger_count": get_value(pc_col, i),
            "fare_amount": get_value(fare_col, i),
        }


        # Optional: attach partition info for traceability (e.g., year/month)
        if partition_values:
            doc.setdefault("meta", {})["partition"] = partition_values

        docs.append(doc)
    return docs

def main():
    """Load every parquet fragment under PARQUET_ROOT into the trips collection.

    Each fragment is read as a stream of record batches of at most BATCH_SIZE rows,
    each batch is converted with batch_to_docs and written with one unordered
    bulk insert. Running this more than once inserts the same rows again.
    """
    dataset = ds.dataset(PARQUET_ROOT, format="parquet", partitioning="hive")  # understands year=2019/month=01/...

    for fragment in dataset.get_fragments():
        # Extract year/month from a path like .../year=2019/month=01/....
        m = re.search(r"year=(\d{4})/month=(\d{2})", fragment.path)
        part_vals = {"year": m.group(1), "month": m.group(2)} if m else None

        # Stream batches straight from the file rather than materializing the whole
        # fragment as a table first, so memory use is bounded by the batch size.
        for batch in fragment.to_batches(batch_size=BATCH_SIZE):
            docs = batch_to_docs(batch, partition_values=part_vals)
            if docs:
                col.bulk_write([InsertOne(d) for d in docs], ordered=False)
                print(f"Inserted {len(docs)} docs from {fragment.path}")

if __name__ == "__main__":
    main()


In [4]:
# import os, glob, itertools

# BASES = [
#     "/home/jovyan/work/data",
#     "/data/taxi",                          # in case you also mounted this
# ]

# print("Checking common bases…")
# for b in BASES:
#     print(b, "exists?" , os.path.exists(b))

# # Find any parquet files (limit output)
# candidates = []
# for base in BASES:
#     for p in itertools.islice(glob.iglob(base + "/**/*.parquet", recursive=True), 50):
#         candidates.append(p)
# print(f"\nFound {len(candidates)} parquet files (showing up to 50):")
# for p in candidates[:50]:
#     print(" -", p)


In [7]:
#!/usr/bin/env python3
from datetime import datetime, timedelta, timezone
from pymongo import MongoClient
from pymongo.errors import DuplicateKeyError, AutoReconnect
import time

# ------------ toggles ------------
SKIP_CREATE_TIME_INDEX = True  # when True, main() does not create the pickup.time index up front
USE_MANUAL_RANGE = True        # when True, main() uses MANUAL_MIN/MANUAL_MAX instead of get_time_range()
MANUAL_MIN = datetime(2019, 1, 1, tzinfo=timezone.utc)
MANUAL_MAX = MANUAL_MIN + timedelta(days=1)  # just 1 day
# MANUAL_MIN = datetime(2019, 1, 1, tzinfo=timezone.utc)
# MANUAL_MAX = datetime(2019, 2, 1, tzinfo=timezone.utc)  # one month
CREATE_UNIQUE_INDEX = True     # when True, main() builds the partial unique index on trip_id at the end
# ---------------------------------

MONGO_URL = "mongodb://mongouser:mongopassword@mongodb:27017/taxi_logs?authSource=admin"
DB_NAME = "taxi_logs"
COLL_NAME = "trips"

print("Connecting…", flush=True)
# 60 s server-selection/connect timeouts with retryable reads and writes.
# NOTE(review): socketTimeoutMS=0 presumably means "no socket timeout" — confirm.
client = MongoClient(
    MONGO_URL,
    retryWrites=True,
    retryReads=True,
    serverSelectionTimeoutMS=60000,
    connectTimeoutMS=60000,
    socketTimeoutMS=0,
)
col = client[DB_NAME][COLL_NAME]

def with_retries(fn, what, attempts=5):
    """Call fn() and return its result, retrying when AutoReconnect is raised.

    Up to `attempts` calls are guarded; after each AutoReconnect a message naming
    `what` is printed and the function sleeps 1.5 * 2**k seconds (k = 0, 1, ...).
    If all guarded calls fail, one more unguarded call is made so the error
    propagates to the caller.
    """
    tries_done = 0
    while tries_done < attempts:
        try:
            return fn()
        except AutoReconnect:
            delay = 1.5 * (2 ** tries_done)
            tries_done += 1
            print(f"⚠️  AutoReconnect during {what}. Retry {tries_done}/{attempts} in {delay:.1f}s …", flush=True)
            time.sleep(delay)
    # final call is deliberately not wrapped so a persistent failure surfaces
    return fn()

def utc(dt):  # ensure tz-aware
    """Return dt unchanged if it carries tzinfo, otherwise a copy labelled as UTC."""
    if dt.tzinfo:
        return dt
    return dt.replace(tzinfo=timezone.utc)

def month_bounds(start_dt, end_dt):
    """Yield (window_start, window_end) pairs covering start_dt's month up to end_dt.

    The first window starts on the first day of start_dt's month (UTC); every
    following window starts on the first day of the next month. Each window end is
    the next month start, capped at end_dt. All yielded datetimes are tz-aware.

    Args:
        start_dt: Datetime whose year/month select the first window.
        end_dt: Exclusive upper bound. A naive value is treated as UTC; without this
            the comparison against the tz-aware window starts raises TypeError
            (PyMongo returns naive datetimes unless the client is tz_aware).
    """
    if end_dt.tzinfo is None:
        end_dt = end_dt.replace(tzinfo=timezone.utc)

    cur = datetime(start_dt.year, start_dt.month, 1, tzinfo=timezone.utc)
    while cur < end_dt:
        if cur.month == 12:
            nxt = datetime(cur.year + 1, 1, 1, tzinfo=timezone.utc)
        else:
            nxt = datetime(cur.year, cur.month + 1, 1, tzinfo=timezone.utc)
        yield cur, min(nxt, end_dt)
        cur = nxt

def get_time_range():
    """Find the earliest and latest pickup.time stored as a date in the collection.

    Returns:
        (min pickup.time, max pickup.time + 1 day); the extra day turns the maximum
        into an exclusive upper bound.

    Raises:
        SystemExit: when the aggregation yields no usable min/max.
    """
    print("Computing min/max pickup.time (can be slow on large collections)…", flush=True)
    only_dates = {"$match": {"pickup.time": {"$type": "date"}}}
    extremes = {"$group": {"_id": None, "minT": {"$min": "$pickup.time"}, "maxT": {"$max": "$pickup.time"}}}
    rows = with_retries(
        lambda: list(col.aggregate([only_dates, extremes], allowDiskUse=True)),
        "aggregate(min/max)",
    )
    usable = bool(rows) and rows[0]["minT"] is not None and rows[0]["maxT"] is not None
    if not usable:
        raise SystemExit("No documents with pickup.time as a date were found.")
    earliest, latest = rows[0]["minT"], rows[0]["maxT"]
    # make max exclusive by adding 1 day to ensure final month closes
    return earliest, latest + timedelta(days=1)

def count_month(start, end):
    """Count documents with start <= pickup.time < end (retried via with_retries)."""
    in_window = {"pickup.time": {"$gte": start, "$lt": end}}

    def run():
        return col.count_documents(in_window)

    return with_retries(run, "count_documents(month)")

def count_missing_time_bucket(start, end):
    """Count documents in [start, end) with a date-typed pickup.time and no meta.time_bucket."""
    pickup_range = {"$gte": start, "$lt": end, "$type": "date"}

    def run():
        return col.count_documents({
            "pickup.time": pickup_range,
            "meta.time_bucket": {"$exists": False},
        })

    return with_retries(run, "count TB_missing")

def count_missing_trip_id(start, end):
    """Count documents with start <= pickup.time < end that have no trip_id field."""
    pickup_range = {"$gte": start, "$lt": end}

    def run():
        return col.count_documents({
            "pickup.time": pickup_range,
            "trip_id": {"$exists": False},
        })

    return with_retries(run, "count ID_missing")

def backfill_time_bucket(start, end):
    """Set meta.time_bucket (pickup.time truncated to the hour) where it is missing.

    Only documents in [start, end) with a date-typed pickup.time are touched.

    Returns:
        (matched_count, modified_count, elapsed_seconds) for the update_many call.
    """
    needs_bucket = {
        "pickup.time": {"$gte": start, "$lt": end, "$type": "date"},
        "meta.time_bucket": {"$exists": False},
    }
    hour_of_pickup = {"$dateTrunc": {"date": "$pickup.time", "unit": "hour"}}
    pipeline = [{"$set": {"meta.time_bucket": hour_of_pickup}}]

    started = time.perf_counter()
    outcome = with_retries(lambda: col.update_many(needs_bucket, pipeline), "update_many(time_bucket)")
    return outcome.matched_count, outcome.modified_count, time.perf_counter() - started

def backfill_trip_id(start, end):
    """Set a synthetic trip_id on documents in [start, end) that lack one.

    The id is "<pickup time as ISO-8601 UTC with milliseconds>:<pickup location>:
    <dropoff location>:<fare amount>", with "na" substituted for a missing part.

    NOTE(review): the recorded run of this notebook shows the unique index build
    failing on a duplicate trip_id, so these four fields do not identify every trip.

    Returns:
        (matched_count, modified_count, elapsed_seconds) for the update_many call.
    """
    lacks_id = {"pickup.time": {"$gte": start, "$lt": end}, "trip_id": {"$exists": False}}

    def text_or_na(field):
        # Aggregation expression: the field rendered as a string, or "na" when null/missing.
        return {"$toString": {"$ifNull": [field, "na"]}}

    pickup_stamp = {"$dateToString": {"date": "$pickup.time", "format": "%Y-%m-%dT%H:%M:%S.%LZ", "timezone": "UTC"}}
    id_parts = [
        pickup_stamp,
        ":", text_or_na("$pickup.location_id"),
        ":", text_or_na("$dropoff.location_id"),
        ":", text_or_na("$fare_amount"),
    ]
    pipeline = [{"$set": {"trip_id": {"$concat": id_parts}}}]

    started = time.perf_counter()
    outcome = with_retries(lambda: col.update_many(lacks_id, pipeline), "update_many(trip_id)")
    return outcome.matched_count, outcome.modified_count, time.perf_counter() - started

def main():
    """Run the month-by-month backfill and print a progress table and summary.

    Steps: optionally create the pickup.time index, choose the time window (manual
    constants or get_time_range), backfill meta.time_bucket and trip_id for every
    month window, optionally build the partial unique index on trip_id, then print
    the totals and the elapsed time.
    """
    # 0) (Optional) index creation can be deferred so output appears sooner
    if not SKIP_CREATE_TIME_INDEX:
        print("Creating index on pickup.time …", flush=True)
        with_retries(
            lambda: col.create_index([("pickup.time", 1)], name="idx_pickup_time"),
            "create_index(pickup.time)",
        )
        print("Index on pickup.time created.", flush=True)

    # 1) Choose the window to process
    if USE_MANUAL_RANGE:
        window_start, window_end = MANUAL_MIN, MANUAL_MAX
        print(f"Using manual window: {window_start.isoformat()} → {window_end.isoformat()}", flush=True)
    else:
        window_start, window_end = get_time_range()
        print(f"Discovered window:   {window_start.isoformat()} → {window_end.isoformat()}", flush=True)

    for banner in (
        "\n=== Month-by-month Backfill (by pickup.time) ===",
        "time_bucket → rounds pickup.time to the hour for faster grouping",
        "trip_id     → synthetic unique key to prevent duplicates\n",
    ):
        print(banner, flush=True)
    header = (f"{'MonthStart(UTC)':<20}{'Total':>10} | "
              f"{'TB_missing':>11} -> {'TB_mod':>7} ({'s':>5}) | "
              f"{'ID_missing':>11} -> {'ID_mod':>7} ({'s':>5})")
    print(header, flush=True)

    docs_seen = tb_changed = id_changed = 0
    started = time.perf_counter()

    for lo, hi in month_bounds(window_start, window_end):
        month_total = count_month(lo, hi)
        docs_seen += month_total
        if month_total == 0:
            # empty window: print an all-zero row and skip the updates
            print(f"{lo.isoformat():<20}{0:>10} | {0:>11} -> {0:>7} ({0:>5}) | {0:>11} -> {0:>7} ({0:>5})",
                  flush=True)
            continue

        # "missing" counts are taken before the matching update runs
        tb_missing = count_missing_time_bucket(lo, hi)
        _, tb_mod, tb_secs = backfill_time_bucket(lo, hi)

        id_missing = count_missing_trip_id(lo, hi)
        _, id_mod, id_secs = backfill_trip_id(lo, hi)

        tb_changed += tb_mod
        id_changed += id_mod

        row = (f"{lo.isoformat():<20}{month_total:>10} | "
               f"{tb_missing:>11} -> {tb_mod:>7} ({tb_secs:>5.1f}) | "
               f"{id_missing:>11} -> {id_mod:>7} ({id_secs:>5.1f})")
        print(row, flush=True)

    # 2) Build unique index at the end (with retry)
    if CREATE_UNIQUE_INDEX:
        print("\nCreating partial unique index on trip_id …", flush=True)
        index_options = dict(
            unique=True,
            partialFilterExpression={"trip_id": {"$type": "string"}},
            name="uniq_trip_id_partial",
        )
        try:
            with_retries(
                lambda: col.create_index([("trip_id", 1)], **index_options),
                "create_index(trip_id unique)",
            )
            print("Unique index created on trip_id (partial).", flush=True)
        except DuplicateKeyError as dup_err:
            # report and continue so the summary below is still printed
            print("⚠️  DuplicateKeyError while creating unique index. Investigate duplicates.", flush=True)
            print(dup_err, flush=True)

    elapsed = time.perf_counter() - started
    print("\n=== Summary ===", flush=True)
    print(f"Docs scanned (sum of month totals): {docs_seen:,}", flush=True)
    print(f"time_bucket modified: {tb_changed:,}", flush=True)
    print(f"trip_id modified:    {id_changed:,}", flush=True)
    print(f"Elapsed: {elapsed/60:.1f} minutes", flush=True)

main()


Connecting…
Using manual window: 2019-01-01T00:00:00+00:00 → 2019-01-02T00:00:00+00:00

=== Month-by-month Backfill (by pickup.time) ===
time_bucket → rounds pickup.time to the hour for faster grouping
trip_id     → synthetic unique key to prevent duplicates

MonthStart(UTC)          Total |  TB_missing ->  TB_mod (    s) |  ID_missing ->  ID_mod (    s)
2019-01-01T00:00:00+00:00    189432 |           0 ->       0 (  0.4) |      189432 ->  189432 (  7.4)

Creating partial unique index on trip_id …
⚠️  DuplicateKeyError while creating unique index. Investigate duplicates.
Index build failed: 26362b9a-1aac-40ee-82e7-93011b855bee: Collection taxi_logs.trips ( 81f13904-2100-4505-8de9-a4728b198a32 ) :: caused by :: E11000 duplicate key error collection: taxi_logs.trips index: uniq_trip_id_partial dup key: { trip_id: "2019-01-01T00:25:12.000Z:141:263:5" }, full error: {'ok': 0.0, 'errmsg': 'Index build failed: 26362b9a-1aac-40ee-82e7-93011b855bee: Collection taxi_logs.trips ( 81f13904-2100-4

In [None]:
from pymongo import MongoClient, ASCENDING, GEOSPHERE
import time
import sys

def progress_bar(step, total_steps, label=""):
    """Redraw a 30-character text progress bar on the current stdout line.

    Args:
        step: Number of completed steps.
        total_steps: Total number of steps. A value <= 0 draws an empty bar instead
            of raising ZeroDivisionError.
        label: Text shown after the "step/total" counter.

    The filled portion is clamped to the bar width, so step > total_steps draws a
    full bar rather than an over-long one.
    """
    bar_len = 30
    if total_steps > 0:
        filled = int(round(bar_len * step / float(total_steps)))
    else:
        filled = 0
    filled = max(0, min(bar_len, filled))
    bar = "█" * filled + "-" * (bar_len - filled)
    # "\r" returns to the start of the line so successive calls overwrite each other
    sys.stdout.write(f"\r[{bar}] {step}/{total_steps} {label}")
    sys.stdout.flush()

client = MongoClient("mongodb://mongouser:mongopassword@mongodb:27017/taxi_logs?authSource=admin")
c = client.taxi_logs.trips

# Each entry is (key list, label). The key list is passed to create_index as is, so
# an entry with two (field, direction) pairs builds a compound index. The two
# "location_id + time" entries now actually include the time field their label names.
indexes = [
    ([("pickup.time", ASCENDING)], "pickup.time"),
    ([("dropoff.time", ASCENDING)], "dropoff.time"),
    ([("pickup.location_id", ASCENDING), ("pickup.time", ASCENDING)], "pickup.location_id + pickup.time"),
    ([("dropoff.location_id", ASCENDING), ("dropoff.time", ASCENDING)], "dropoff.location_id + dropoff.time"),
    ([("meta.time_bucket", ASCENDING)], "meta.time_bucket"),
    ([("pickup.loc", GEOSPHERE)], "pickup.loc (geospatial)"),
]

total = len(indexes)
for i, (spec, label) in enumerate(indexes, start=1):
    c.create_index(spec)
    progress_bar(i, total, label)
    time.sleep(0.5)  # just to make the bar visible (remove in real run)

print("\nAll indexes created.")


[█████-------------------------] 1/6 pickup.time

In [None]:
from datetime import datetime, timezone

# See all indexes
for ix in c.list_indexes():
    print(ix)

# Quick sanity: ensure a query uses your index.
# ISODate(...) exists only in the mongo shell; PyMongo takes a Python datetime.
# Cursor.explain() accepts no verbosity argument in PyMongo, so it is called bare.
plan = c.find(
    {"pickup.location_id": 132, "pickup.time": {"$gte": datetime(2019, 1, 1, 8, 0, 0, tzinfo=timezone.utc)}},
    {"_id": 0}
).explain()
print(plan["queryPlanner"]["winningPlan"])


In [None]:
# Step 7 — Query & Validate

from datetime import datetime, timezone
from pymongo import MongoClient

# ---- Connect ----
MONGO_URL = "mongodb://mongouser:mongopassword@mongodb:27017/taxi_logs?authSource=admin"
client = MongoClient(MONGO_URL)
c = client.taxi_logs.trips

# ---- 1) Sanity counts vs Hive (Trips in Jan 2019, UTC) ----
jan1 = datetime(2019, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
feb1 = datetime(2019, 2, 1, 0, 0, 0, tzinfo=timezone.utc)

count_jan = c.count_documents({
    "pickup.time": {"$gte": jan1, "$lt": feb1}
})
print("Trips in 2019-01:", count_jan)

# ---- 2) Targeted query using index (limit 3) ----
docs = list(
    c.find(
        {
            "pickup.location_id": 132,
            "pickup.time": {"$gte": jan1, "$lt": jan1.replace(day=2)}
        },
        # optional projection:
        {"_id": 0, "pickup": 1, "dropoff": 1, "fare_amount": 1}
    ).limit(3)
)
print("Sample docs:", docs)

# ---- 3) Check it uses your index (explain) ----
# PyMongo's Cursor.explain() takes no verbosity argument (passing one raises
# TypeError); its default output already contains the executionStats section.
explain = c.find(
    {
        "pickup.location_id": 132,
        "pickup.time": {"$gte": jan1, "$lt": jan1.replace(day=2)}
    },
    {"_id": 0}
).explain()

# Helper to summarize plan quality
def summarize_explain(exp):
    """Print a short summary of an explain() result and whether it used an index scan.

    Prints the winning plan's top stage, every IXSCAN/FETCH stage found in the plan
    tree, the nReturned/totalDocsExamined/totalKeysExamined counters, and a verdict
    line. Missing sections are treated as empty dicts.

    Args:
        exp: The dict returned by explain().

    Returns:
        None; all results are printed.
    """
    qp = exp.get("queryPlanner", {})
    winning = qp.get("winningPlan", {})
    # Servers using the slot-based execution engine nest the stage tree under
    # winningPlan["queryPlan"]; unwrap it so the top stage is reported correctly.
    if isinstance(winning.get("queryPlan"), dict):
        winning = winning["queryPlan"]
    execstats = exp.get("executionStats", {})
    stage = winning.get("stage", "")

    # Walk the nested plan tree and collect IXSCAN / FETCH stages in visit order
    def find_stages(node, hits=None):
        if hits is None:
            hits = []
        if not isinstance(node, dict):
            return hits
        if node.get("stage") in ("IXSCAN", "FETCH"):
            hits.append(node.get("stage"))
        # "winningPlan"/"queryPlan" cover per-shard entries and slot-based plans
        for key in ("inputStage", "inputStages", "shards", "winningPlan", "queryPlan"):
            child = node.get(key)
            if isinstance(child, dict):
                find_stages(child, hits)
            elif isinstance(child, list):
                for ch in child:
                    find_stages(ch, hits)
        return hits

    stages = find_stages(winning)
    print("Winning stage:", stage, "| Stages seen:", stages)
    print("nReturned:", execstats.get("nReturned"))
    print("totalDocsExamined:", execstats.get("totalDocsExamined"))
    print("totalKeysExamined:", execstats.get("totalKeysExamined"))
    # Rule of thumb: keysExamined ≫ 0 and docsExamined ≪ collection size, and IXSCAN present
    if "IXSCAN" in stages:
        print("✅ Index scan detected (good).")
    else:
        print("⚠️ No IXSCAN detected—query may not be using your index.")

# Print the plan summary for the explain result captured in step 3 above.
summarize_explain(explain)
