# Kafka streaming

First, we modified the docker-compose file to include an additional container for hosting Jupyter:

  # JupyterLab service used to run the notebooks in this write-up.
  jupyter:
    image: python:3.9-slim
    container_name: jupyter-kafka
    platform: linux/amd64
    ports:
      # Publish JupyterLab's port (see --port below) on the host.
      - "8888:8888"
    volumes:
      # The local ./jupyter directory is mounted as the working directory.
      - ./jupyter:/workspace
    working_dir: /workspace
    # Installs the Python dependencies on every container start, then launches
    # JupyterLab listening on all interfaces.
    # NOTE(review): the empty token appears to disable token authentication;
    # confirm this is only used for local development.
    command: >
      bash -c "pip install jupyterlab pandas scikit-learn pyarrow faust confluent-kafka plotly bytewax &&
               jupyter lab --ip=0.0.0.0 --port=8888 --allow-root --NotebookApp.token=''"
    # Start after the two Kafka broker services.
    depends_on:
      - broker1-kr
      - broker2-kr

We wrote a simple producer that loaded one of the .parquet files, ordered it by pickup date and produced a message every 0.5s.

In [None]:
from confluent_kafka import Producer
import pandas as pd
import json
import time

# Load one partition of the data set and replay it in pickup order.
df = pd.read_parquet("./data/part.191.parquet")
df = df.sort_values("pickup_datetime")

# Kafka config
conf = {
    'bootstrap.servers': 'broker1-kr:9092'
}
producer = Producer(conf)

topic = 'yellow_taxi_stream'


def _to_json_safe(value):
    """Return *value* in a form that ``json.dumps`` renders as strict JSON.

    Timestamps become ISO-8601 strings. Missing values (``NaT``, ``NA`` and
    float NaN) become ``None`` so they are written as JSON ``null``;
    ``json.dumps`` would otherwise emit the non-standard token ``NaN`` for
    floats and raise ``TypeError`` for ``NaT``/``NA``.
    """
    if value is pd.NaT or value is pd.NA:
        return None
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if isinstance(value, float) and pd.isna(value):
        return None
    return value


def _report_delivery(err, msg):
    """Delivery callback: print failures that would otherwise go unnoticed."""
    if err is not None:
        print(f"Delivery failed: {err}")


for _, row in df.iterrows():
    data = {key: _to_json_safe(value) for key, value in row.to_dict().items()}

    message = json.dumps(data)
    producer.produce(topic, value=message, on_delivery=_report_delivery)
    # Flushing per message keeps the demo strictly paced: each record is
    # handed to the broker (and its callback served) before the next sleep.
    producer.flush()
    print(f"Produced: {message}")
    time.sleep(0.5)



We also wrote a simple consumer that reads from the same topic and prints out the messages:

In [None]:
from confluent_kafka import Consumer
import json

# Consumer settings: start from the oldest available offset when the group
# has no committed position yet.
conf = {
    'bootstrap.servers': 'broker1-kr:9092',
    'group.id': 'taxi_consumer_group',
    'auto.offset.reset': 'earliest',
}
consumer = Consumer(conf)
consumer.subscribe(['yellow_taxi_stream'])

print("Consuming messages from 'yellow_taxi_stream'...")

try:
    # Poll until the cell is interrupted; every decoded message is printed.
    while True:
        msg = consumer.poll(1.0)
        if msg is not None:
            if msg.error():
                print("Consumer error:", msg.error())
            else:
                data = json.loads(msg.value().decode('utf-8'))
                print("Consumed:", data)
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to stop the loop.
    pass
finally:
    consumer.close()


We implemented a simple Faust app that calculates rolling statistics for each borough and writes them back to Kafka:

In [None]:
import faust
from typing import Optional
from statistics import mean, stdev

# Faust application for the rolling-statistics worker; it connects to
# broker1-kr and uses the JSON serializer for message values.
app = faust.App(
    'taxi-stream-app',
    broker='kafka://broker1-kr:9092',
    value_serializer='json',
)

# Schema of incoming messages
class TaxiRecord(faust.Record, serializer='json'):
    """Fields of an incoming taxi-trip message that the stats agent reads.

    ``pickup_borough`` and ``trip_distance`` are read by the agent and are
    therefore declared here as well. They default to ``None`` so a message
    that lacks them yields ``None`` instead of a missing attribute.
    """

    total_amount: Optional[float]
    passenger_count: Optional[int]
    # New fields are appended so the existing positional order is unchanged.
    pickup_borough: Optional[str] = None
    trip_distance: Optional[float] = None

# Schema for outgoing messages
class StatsRecord(faust.Record, serializer='json'):
    """Per-borough statistics message sent to the stats topic by the agent."""
    borough: str  # pickup borough the statistics belong to
    mean_fare: float  # mean of the windowed total amounts
    std_fare: float  # sample standard deviation of the windowed total amounts
    mean_psg: float  # mean of the windowed passenger counts
    std_psg: float  # sample standard deviation of the windowed passenger counts
    mean_dist: float  # mean of the windowed trip distances
    std_dist: float  # sample standard deviation of the windowed trip distances
    count: int  # running number of records for the borough (not limited to the window)



# Input topic (raw trips written by the producer) and output topic
# (statistics computed by the agent).
taxi_topic = app.topic('yellow_taxi_stream', value_type=TaxiRecord)
stats_topic = app.topic('yellow_taxi_stats', value_serializer='json')


def _empty_stats():
    """Build the default table value: zero count and empty value lists."""
    return {
        'count': 0,
        'total_amounts': [],
        'passengers': [],
        'trip_distance': [],
    }


# Per-borough state, backed by an explicitly named single-partition changelog.
stats_changelog = app.topic('custom_stats_changelog', partitions=1)
stats_table = app.Table(
    'total_stats',
    default=_empty_stats,
    partitions=1,
    changelog_topic=stats_changelog,
)


_STATS_WINDOW = 100  # number of most recent values kept per borough and metric


def _as_number(value):
    """Return *value*, or 0 when it is missing (``None``) or NaN.

    ``value or 0`` is not enough here: NaN is truthy, so it would be stored
    and turn every mean/stdev of the window into NaN.
    """
    if value is None or value != value:  # NaN is the only value unequal to itself
        return 0
    return value


@app.agent(taxi_topic)
async def process(taxis):
    """Maintain rolling per-borough statistics and publish them.

    For every incoming trip the borough's entry in ``stats_table`` is
    updated (running count plus the last ``_STATS_WINDOW`` fares, passenger
    counts and distances). Once at least two values are available, the
    means and sample standard deviations are printed and sent to
    ``stats_topic`` as a ``StatsRecord``.
    """
    async for taxi in taxis:
        borough = taxi.pickup_borough
        stats = stats_table[borough]

        # Update stats
        stats['count'] += 1
        stats['total_amounts'].append(_as_number(taxi.total_amount))
        stats['passengers'].append(_as_number(taxi.passenger_count))
        stats['trip_distance'].append(_as_number(taxi.trip_distance))

        # Keep only the most recent values; a no-op while the list is short.
        for key in ('total_amounts', 'passengers', 'trip_distance'):
            del stats[key][:-_STATS_WINDOW]

        # Re-assign so the table registers the change made to the dict.
        stats_table[borough] = stats

        # stdev needs at least 2 values
        if len(stats['total_amounts']) < 2:
            continue

        mean_fare = mean(stats['total_amounts'])
        std_fare = stdev(stats['total_amounts'])
        mean_psg = mean(stats['passengers'])
        std_psg = stdev(stats['passengers'])
        mean_dist = mean(stats['trip_distance'])
        std_dist = stdev(stats['trip_distance'])
        count = stats['count']

        print(f"{borough} Count={stats['count']}")
        print(f"  💰 Mean Fare: {mean_fare:.2f}, Std: {std_fare:.2f}")
        print(f"  👥 Mean Passengers: {mean_psg:.2f}, Std: {std_psg:.2f}")
        print(f"  📏 Mean Distance: {mean_dist:.2f}, Std: {std_dist:.2f}")

        stats_msg = StatsRecord(
            borough=borough,
            mean_fare=mean_fare,
            std_fare=std_fare,
            mean_psg=mean_psg,
            std_psg=std_psg,
            count=count,
            mean_dist=mean_dist,
            std_dist=std_dist,
        )

        await stats_topic.send(value=stats_msg)
        



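For the stream clustering algorithm we decided to implement an online variant of K-Means clustering. We start from 3 initially chosen centroids; each incoming pickup location is assigned to its nearest centroid, which is then moved to the running mean of the points assigned to it.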

In [None]:
import faust
import math
from typing import Optional, List
import numpy as np

# Faust application for the clustering worker; it connects to broker1-kr
# and uses the JSON serializer for message values.
app = faust.App(
    'taxi-cluster-app',
    broker='kafka://broker1-kr:9092',
    value_serializer='json',
)

class TaxiRecord(faust.Record, serializer='json'):
    """Fields of an incoming taxi-trip message declared for the clustering app."""
    pickup_latitude: Optional[float]  # read by the agent; records without it are skipped
    pickup_longitude: Optional[float]  # read by the agent; records without it are skipped
    total_amount: Optional[float]  # declared but not read by the agent in this cell
    passenger_count: Optional[int]  # declared but not read by the agent in this cell

class ClusterCentroid(faust.Record, serializer='json'):
    """Updated position of one centroid, sent to the cluster topic by the agent."""
    cluster_id: int  # key of the centroid in the ``centroids`` table
    lng: float  # centroid longitude
    lat: float  # centroid latitude

# Input topic with raw trips and output topic with centroid updates.
taxi_topic = app.topic('yellow_taxi_stream', value_type=TaxiRecord)
cluster_topic = app.topic('yellow_taxi_cluster_stream', value_type=ClusterCentroid)

# Number of clusters
# NOTE(review): K is not referenced by the code in this cell; the number of
# centroids actually follows from len(initial_centroids).
K = 3

# Starting positions of the centroids (arbitrary example coordinates)
initial_centroids = [
    {'lat': 40.7580, 'lon': -73.9855},  # Times Square approx
    {'lat': 40.7128, 'lon': -74.0060},  # Lower Manhattan approx
    {'lat': 40.730610, 'lon': -73.935242},  # East Village approx
]

# Table: key=cluster_id, value=dict with 'lat', 'lon', 'count'
# ('count' is the number of points that have been assigned to the centroid).
centroids = app.Table(
    'centroids',
    default=lambda: {'lat': 0.0, 'lon': 0.0, 'count': 0},
    partitions=1,
    changelog_topic=app.topic('custom_stats_changelog_centroid', partitions=1)
)

# Module-level flag: set to True by the agent once it has written the
# initial centroids into the table, so that happens only once per process.
initialized = False

def _distance(c, lat, lon):
    """Return the Euclidean distance (in degrees) from centroid *c* to (lat, lon)."""
    return math.hypot(c['lat'] - lat, c['lon'] - lon)


@app.agent(taxi_topic)
async def process(taxis):
    """Online K-Means over pickup locations.

    Each trip with a usable pickup position is assigned to its nearest
    centroid; that centroid is moved to the running mean of all points
    assigned to it so far, stored back in the ``centroids`` table and
    published to ``cluster_topic``.
    """
    global initialized  # to modify the module-level flag

    async for taxi in taxis:
        if not initialized:
            # Seed the table with the starting centroids on the first message.
            for i, c in enumerate(initial_centroids):
                centroids[i] = {'lat': c['lat'], 'lon': c['lon'], 'count': 0}
            initialized = True  # only run once

        if taxi.pickup_latitude is None or taxi.pickup_longitude is None:
            continue

        lat = float(taxi.pickup_latitude)
        lon = float(taxi.pickup_longitude)
        # NaN never compares equal to anything (not even np.nan), so it has
        # to be detected with isnan; a NaN point would corrupt the centroid.
        if math.isnan(lat) or math.isnan(lon):
            continue

        closest_id = min(
            centroids.keys(),
            key=lambda cid: _distance(centroids[cid], lat, lon)
        )

        centroid = centroids[closest_id]
        count = centroid['count']

        # Incremental mean: fold the new point into the centroid position.
        new_count = count + 1
        new_lat = (centroid['lat'] * count + lat) / new_count
        new_lon = (centroid['lon'] * count + lon) / new_count

        centroids[closest_id] = {'lat': new_lat, 'lon': new_lon, 'count': new_count}
        print(lon, lat)
        print(f"Cluster {closest_id}: Lat {new_lat:.5f}, Lon {new_lon:.5f}, Count {new_count}")

        cluster_msg = ClusterCentroid(
            cluster_id=closest_id,
            lng=new_lon,
            lat=new_lat
        )
        await cluster_topic.send(value=cluster_msg)
        




We also prepared a simple visualization using plotly:

In [None]:
from confluent_kafka import Consumer
import json
import plotly.express as px
import pandas as pd
import numpy as np
import time
from IPython.display import display, clear_output

# Number of centroids shown on the map (one row per centroid below).
NUM_CLUSTERS = 3

# Current centroid positions, one [lat, lon] row per cluster id. These are
# placeholder values that are overwritten as updates arrive from Kafka.
centroids = np.array([[40.75, -74.0], [40.73, -73.93], [40.70, -74.15]])

# Consumer settings: start from the oldest available offset when the group
# has no committed position yet.
# NOTE(review): this is the same group.id as the plain consumer cell uses
# for 'yellow_taxi_stream' - confirm sharing the group is intended.
conf = {
    'bootstrap.servers': 'broker1-kr:9092',
    'group.id': 'taxi_consumer_group',
    'auto.offset.reset': 'earliest'
}
consumer = Consumer(conf)
consumer.subscribe(['yellow_taxi_cluster_stream'])

print("Consuming messages from 'yellow_taxi_cluster_stream'...")

REDRAW_INTERVAL_S = 3  # minimum number of seconds between two redraws


def _render(points):
    """Replace the cell output with a map showing the centroids in *points*."""
    df = pd.DataFrame(points, columns=["lat", "lon"])
    df["cluster"] = [f"Cluster {i}" for i in range(NUM_CLUSTERS)]

    fig = px.scatter_mapbox(df, lat="lat", lon="lon", color="cluster", zoom=9)
    fig.update_layout(mapbox_style="carto-positron",
                      mapbox_center={"lat": 40.73, "lon": -73.98},
                      margin={"r":0,"t":0,"l":0,"b":0})

    clear_output(wait=True)
    display(fig)


# Consume continuously and redraw at most every REDRAW_INTERVAL_S seconds.
# Sleeping after every message would let the plot fall further and further
# behind a stream that produces updates faster than one per interval.
last_draw = time.monotonic() - REDRAW_INTERVAL_S  # allow an immediate first draw
dirty = False  # True when centroids changed since the last redraw

try:
    while True:
        msg = consumer.poll(1.0)
        if msg is not None:
            if msg.error():
                print("Consumer error:", msg.error())
                continue

            update = json.loads(msg.value().decode('utf-8'))
            cid = update['cluster_id']
            # Guard the array index: a negative id would silently wrap
            # around, and a too-large one would end the loop with IndexError.
            if not isinstance(cid, int) or not 0 <= cid < NUM_CLUSTERS:
                print("Ignoring update for unknown cluster:", cid)
                continue
            centroids[cid] = [update['lat'], update['lng']]
            dirty = True

        # Also reached on an idle poll, so the last pending update is still
        # drawn after the stream goes quiet.
        if dirty and time.monotonic() - last_draw >= REDRAW_INTERVAL_S:
            _render(centroids)
            last_draw = time.monotonic()
            dirty = False
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to stop the loop.
    pass
finally:
    consumer.close()











Lastly, we used the Bytewax Python library to calculate a rolling average of the fare amount per borough:

In [None]:
from bytewax import operators as op
from bytewax.dataflow import Dataflow
from bytewax.connectors.kafka import KafkaSource
from bytewax.connectors.stdio import StdOutSink
import json

WINDOW_SIZE = 5
brokers = ["broker1-kr:9092"]
flow = Dataflow("rolling_avg_per_borough")

# Kafka input
stream = op.input("in", flow, KafkaSource(brokers, ["yellow_taxi_stream"]))


def _has_amount(record):
    """Return True when *record* carries a usable numeric ``total_amount``."""
    amount = record.get("total_amount")
    # `amount == amount` is False only for NaN, which would poison the average.
    return isinstance(amount, (int, float)) and amount == amount


# Parse first, then key by the borough stored in the payload. The producer
# in this notebook publishes messages without a Kafka key, so keying on
# `msg.key` would fail on a missing key.
# NOTE(review): assumes the payload has a `pickup_borough` field, as read by
# the Faust statistics agent - confirm against the data set.
parsed = op.map("parse_json", stream, lambda msg: json.loads(msg.value.decode("utf-8")))
parsed = op.filter("has_amount", parsed, _has_amount)
keyed = op.key_on("key_by_borough", parsed, lambda rec: str(rec.get("pickup_borough")))
keyed = op.map_value("get_amount", keyed, lambda rec: float(rec["total_amount"]))

# Rolling average calculation
def rolling_avg(state, new_value, window_size=None):
    """Fold *new_value* into the rolling window and return its average.

    Args:
        state: List of the most recent values for this key, or ``None`` when
            the key is seen for the first time. The list is updated in place.
        new_value: The value to add to the window.
        window_size: Maximum number of values kept. Defaults to the
            module-level ``WINDOW_SIZE``.

    Returns:
        A ``(state, avg)`` tuple: the updated window and the average of its
        values rounded to two decimals.

    Raises:
        ValueError: If the window size is smaller than 1.
    """
    if window_size is None:
        window_size = WINDOW_SIZE
    if window_size < 1:
        raise ValueError("window_size must be at least 1")
    if state is None:
        state = []

    print("Before state:", state)
    print("New value:", new_value)

    state.append(new_value)
    # Drop every surplus value at once, so a state that is longer than the
    # window by more than one element is trimmed as well.
    overflow = len(state) - window_size
    if overflow > 0:
        del state[:overflow]

    print("After state:", state)
    avg = round(sum(state) / len(state), 2)
    print("Running avg:", avg)
    print()

    return (state, avg)

# Apply the rolling average per key: `rolling_avg` is the stateful mapper for
# the keyed stream; it takes the key's previous state (handled as None on
# first use) plus the new value and returns the updated state and the average.
rolling_avgs = op.stateful_map("rolling_avg", keyed, rolling_avg)

# Write the resulting stream to standard output.
op.output("print_out", rolling_avgs, StdOutSink())
