
# Trip Ingestion Test Notebook

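This notebook demonstrates how to test the mandatory features of the **Trip Ingestion API**.

The first code cell is a standalone check of the ingestion logic. It parses `../source_data/trips.csv` and inserts the rows directly into PostgreSQL, without going through the API. The remaining cells call the API on `localhost:API_PORT`, where `API_PORT` is read from `../.env` and defaults to `8000`. The standalone cell and the upload cell both ingest the same file. If the API writes to the same `trips` table, running both will insert the rows twice.

It covers:
1. Uploading and ingesting a CSV file.
2. Receiving ingestion status updates (via WebSockets).
3. Querying the weekly average number of trips by region or bounding box.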


In [1]:
# Standalone test for ingest_task logic with POINT coordinates
import csv
import os
import re
from datetime import datetime

import geohash2
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

# Connection string for the PostGIS database used by this standalone check.
# Set the DATABASE_URL environment variable to avoid keeping credentials in the
# notebook; the fallback is the original local-development default.
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql+psycopg2://postgres:postgres@localhost:5432/tripsdb",
)
# pool_pre_ping validates pooled connections before use.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine)

# A signed decimal number, optionally with an exponent: "14.5", "-.5", "7.", "2e1".
_WKT_NUMBER = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
# A complete two-coordinate WKT point: "POINT (lon lat)". The space after the
# keyword is optional, inner whitespace is flexible, the keyword is matched
# case-insensitively.
_POINT_RE = re.compile(
    rf"POINT\s*\(\s*({_WKT_NUMBER})\s+({_WKT_NUMBER})\s*\)",
    re.IGNORECASE,
)


def parse_point(point_str):
    """
    Convert a WKT 'POINT (lon lat)' string into a (lat, lon) tuple of floats.

    WKT lists longitude first; the result is returned latitude-first, the order
    get_geohash(lat, lon) takes. Surrounding whitespace is ignored and the
    whole string must be a single POINT (trailing text is rejected).

    Raises:
        ValueError: if point_str is not a string (e.g. None from a missing CSV
            column) or is not a two-coordinate POINT.
    """
    if not isinstance(point_str, str):
        raise ValueError(f"Invalid POINT format: {point_str}")
    # fullmatch, not match: a plain prefix match would accept trailing garbage.
    match = _POINT_RE.fullmatch(point_str.strip())
    if match is None:
        raise ValueError(f"Invalid POINT format: {point_str}")
    lon, lat = match.groups()
    return float(lat), float(lon)

def get_geohash(lat, lon, precision=5):
    """Return the geohash of a (lat, lon) pair at the given precision.

    lat and lon may be numbers or numeric strings; both are coerced to float
    (latitude first) before being passed to geohash2.encode.
    """
    latitude = float(lat)
    longitude = float(lon)
    return geohash2.encode(latitude, longitude, precision=precision)

def get_time_bucket(dt_str):
    """Map an ISO-format timestamp string to a coarse time-of-day label.

    Hours 0-5 -> 'night', 6-11 -> 'morning', 12-17 -> 'afternoon',
    18-23 -> 'evening'. Raises ValueError if dt_str is not ISO-parseable.
    """
    hour_of_day = datetime.fromisoformat(dt_str).hour
    # Guard clauses walk the day in order; hour is always within 0..23.
    if hour_of_day < 6:
        return 'night'
    if hour_of_day < 12:
        return 'morning'
    if hour_of_day < 18:
        return 'afternoon'
    return 'evening'

TRIPS_CSV_PATH = '../source_data/trips.csv'

# One parameterised INSERT reused for every row. Coordinates are bound as plain
# numbers and turned into geometries server-side via ST_Point(lon, lat).
# NOTE(review): ST_Point is called without an SRID; confirm this matches the
# SRID declared on the trips geometry columns.
INSERT_TRIP_SQL = text("""
    INSERT INTO trips (
        region, origin_coord, destination_coord, trip_datetime, datasource,
        origin_geohash, dest_geohash, tod_bucket
    ) VALUES (
        :region,
        ST_Point(:origin_lon, :origin_lat),
        ST_Point(:dest_lon, :dest_lat),
        :trip_datetime,
        :datasource,
        :origin_geohash,
        :dest_geohash,
        :tod_bucket
    )
""")


def _trip_params(row):
    """Build the INSERT bind parameters for one CSV row.

    Raises KeyError if a required column is missing, and ValueError if a
    coordinate is not a valid POINT or the datetime is not ISO-parseable.
    """
    origin_lat, origin_lon = parse_point(row["origin_coord"])
    dest_lat, dest_lon = parse_point(row["destination_coord"])
    trip_dt = row["datetime"]
    return {
        "region": row["region"],
        "origin_lat": origin_lat,
        "origin_lon": origin_lon,
        "dest_lat": dest_lat,
        "dest_lon": dest_lon,
        "trip_datetime": trip_dt,
        "datasource": row["datasource"],
        "origin_geohash": get_geohash(origin_lat, origin_lon),
        "dest_geohash": get_geohash(dest_lat, dest_lon),
        "tod_bucket": get_time_bucket(trip_dt),
    }


# newline='' is what the csv module expects: it lets quoted fields containing
# line breaks parse correctly (splitting the raw text by lines does not).
# Every row is parsed before the database is touched, so a malformed row
# aborts the run before any INSERT is sent.
with open(TRIPS_CSV_PATH, 'r', encoding='utf-8', newline='') as f:
    trip_rows = [_trip_params(row) for row in csv.DictReader(f)]

session = SessionLocal()
try:
    if trip_rows:
        # Passing a list of parameter dicts lets SQLAlchemy use executemany
        # instead of one round trip per row.
        session.execute(INSERT_TRIP_SQL, trip_rows)
    session.commit()
except Exception:
    # Discard the failed transaction, then surface the original error.
    session.rollback()
    raise
finally:
    # Always close the session, even when the insert fails.
    session.close()

inserted_rows = len(trip_rows)
print(f"Inserted {inserted_rows} rows from trips.csv")


Inserted 100 rows from trips.csv


In [None]:

import os
import requests
import pandas as pd
from dotenv import load_dotenv
import websockets
import asyncio

# Read settings from ../.env into the process environment; API_PORT falls
# back to "8000" when it is not defined there or in the shell.
load_dotenv("../.env")
API_PORT = os.environ.get("API_PORT", "8000")
API_URL = "http://localhost:{}".format(API_PORT)


In [None]:

# Upload a CSV for ingestion through the API's /ingest_csv/ endpoint.
csv_path = "../source_data/trips.csv"

with open(csv_path, "rb") as f:
    # An explicit timeout stops the cell from hanging if the API is not running.
    response = requests.post(f"{API_URL}/ingest_csv/", files={"file": f}, timeout=30)

# Fail with the HTTP status instead of a JSON decode error on an error page.
response.raise_for_status()
response.json()


In [None]:

# Check ingestion status via WebSocket
async def check_status():
    uri = f"ws://localhost:{API_PORT}/ws/status"
    async with websockets.connect(uri) as websocket:
        status = await websocket.recv()
        print("Status:", status)

await check_status()


In [None]:

# Query weekly average trips by region
params = {"region": "Prague"}
response = requests.get(f"{API_URL}/trips/weekly_average/", params=params, timeout=30)
# Surface HTTP errors explicitly rather than as a JSON decode failure.
response.raise_for_status()
response.json()


In [None]:

# Query weekly average trips by bounding box.
# NOTE(review): the bbox values look like "min_lon,min_lat,max_lon,max_lat";
# confirm the expected order against the API.
params = {"bbox": "14.40,49.90,14.60,50.10"}
response = requests.get(f"{API_URL}/trips/weekly_average/", params=params, timeout=30)
# Surface HTTP errors explicitly rather than as a JSON decode failure.
response.raise_for_status()
response.json()
