# PyTorch Tutorial 22: Streaming ML Inference (Real-Time AI Systems)

**The Challenge**: Your model needs to make predictions on **millions of events per second** as they happen - fraud detection, recommendation engines, content moderation, real-time bidding.

**Traditional Approach**: Batch processing every few hours ❌
**Modern Approach**: Stream processing with sub-second latency ✅

In 2025, streaming ML infrastructure is a common choice for latency-sensitive production systems at scale. This notebook teaches you:
- How to build real-time inference pipelines with **Kafka + PyTorch**
- What **feature stores** are and why they matter
- How to prevent **training-serving skew** (the #1 production ML bug!)
- Two streaming patterns: **Embedded** vs **Enricher**

## Learning Objectives
1. Understand the **streaming ML architecture**
2. Build a real-time inference system with **Apache Kafka**
3. Integrate **feature stores** for consistent features
4. Avoid **training-serving skew**
5. Deploy streaming ML on **production infrastructure**

---

## Part 1: Vocabulary & Core Concepts

### Key Terms

- **Stream Processing**: Processing data **as it arrives** (not in batches)
- **Apache Kafka**: Distributed event streaming platform (the de facto standard)
- **Apache Flink**: Stream processing framework for real-time analytics
- **Feature Store**: Centralized repository for ML features (online + offline)
- **Training-Serving Skew**: When training and inference use different feature calculations
- **Latency SLA**: Service Level Agreement (e.g., 95% of requests < 100ms)
- **Event-Driven Architecture**: Systems that react to events rather than polling
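
To make the latency SLA term concrete, here is a minimal sketch of checking a "95% of requests < 100ms" target against measured latencies. The helper names (`percentile`, `meets_sla`) and the sample data are hypothetical, and the nearest-rank percentile is just one of several common definitions.

```python
import math

def percentile(values, pct):
    """Nearest-rank percentile: smallest sample with at least pct% of samples at or below it."""
    if not values:
        raise ValueError("need at least one sample")
    ordered = sorted(values)
    rank = math.ceil(pct * len(ordered) / 100)  # 1-based rank
    return ordered[max(rank, 1) - 1]

def meets_sla(latencies_ms, pct=95, budget_ms=100.0):
    """True when the pct-th percentile latency is under the budget."""
    return percentile(latencies_ms, pct) < budget_ms

# Hypothetical samples in milliseconds: 19 fast requests and 1 slow outlier
samples = [12.0] * 19 + [250.0]
print(percentile(samples, 95), meets_sla(samples))          # 12.0 True
print(percentile(samples, 99), meets_sla(samples, pct=99))  # 250.0 False
```

The same data passes a P95 target and fails a P99 target, which is why the percentile is part of the SLA and not just the threshold.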

### Why Streaming ML?

**Batch Inference Problems**:
- ❌ Stale predictions (hours/days old)
- ❌ Can't react to real-time events
- ❌ Wastes compute on unchanged data

**Streaming Inference Benefits**:
- ✅ Sub-second latency
- ✅ React to events immediately
- ✅ Only compute when needed
- ✅ Better user experience

### Real-World Use Cases
- **Fraud Detection**: Flag suspicious transactions in <50ms
- **Personalized Recommendations**: Update recommendations as user browses
- **Content Moderation**: Filter harmful content before it's visible
- **Real-Time Bidding**: Predict ad click probability in <10ms
- **Anomaly Detection**: Detect system failures as they happen

## Part 2: Streaming Architecture Patterns

### Pattern 1: Embedded Model
Model is **embedded directly** into the streaming application.

```
Kafka Stream → Flink App (with model loaded) → Predictions → Output Stream
```

**Pros**: Lowest latency, simple deployment
**Cons**: Hard to update model, duplicates model across instances

### Pattern 2: Enricher (Model Service)
Streaming app **calls a separate ML service** via gRPC/REST.

```
Kafka Stream → Enricher App → [gRPC call] → Model Service → Response → Output Stream
```

**Pros**: Easy model updates, shared service, version control
**Cons**: Extra network hop (~5-10ms latency)

### Which to Use?
| Requirement | Pattern |
|-------------|----------|
| Ultra-low latency (< 10ms) | Embedded |
| Frequent model updates | Enricher |
| Multiple models/versions | Enricher |
| Simple model, stable | Embedded |
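
The difference between the two patterns can be sketched in a few lines. This is an illustration only: `toy_model` and `fake_model_service` are hypothetical stand-ins for a trained model and a gRPC/REST client, and no real network call is made.

```python
from typing import Callable, Dict, List

Features = List[float]

class EmbeddedScorer:
    """Pattern 1: the model lives inside the stream processor."""
    def __init__(self, model_fn: Callable[[Features], float]):
        self.model_fn = model_fn  # loaded once, called in-process

    def score(self, features: Features) -> float:
        return self.model_fn(features)

class EnricherScorer:
    """Pattern 2: the stream processor delegates to a separate model service."""
    def __init__(self, call_service: Callable[[Dict], Dict]):
        self.call_service = call_service  # stands in for a gRPC/REST client

    def score(self, features: Features) -> float:
        response = self.call_service({"features": features})
        return response["fraud_score"]

def toy_model(features: Features) -> float:
    """Hypothetical stand-in for a trained model: clamps the mean into [0, 1]."""
    return min(max(sum(features) / len(features), 0.0), 1.0)

def fake_model_service(request: Dict) -> Dict:
    """Hypothetical stand-in for the remote service; a real one adds a network hop."""
    return {"fraud_score": toy_model(request["features"]), "model_version": "v1"}

def handle_event(scorer, features: Features) -> float:
    # The stream-processing code is identical for both patterns.
    return scorer.score(features)

embedded = EmbeddedScorer(toy_model)
enricher = EnricherScorer(fake_model_service)
print(handle_event(embedded, [0.2, 0.4]), handle_event(enricher, [0.2, 0.4]))
```

Because both scorers expose the same `score` method, swapping patterns changes where the model runs without touching the event-handling code. Updating the model means redeploying the stream app in the first case and only the service in the second.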

## Part 3: Building a Real-Time Fraud Detection System

Let's build a simplified fraud detection pipeline:
1. Kafka receives transaction events
2. PyTorch model scores each transaction
3. Output flagged transactions to another Kafka topic

### Step 1: Setup (Conceptual - requires Kafka)

```bash
# Install dependencies
pip install kafka-python torch

# Start Kafka (Docker)
docker run -d -p 9092:9092 apache/kafka
```

In [None]:
import torch
import torch.nn as nn
import json
import time
from typing import Dict, Any

# Simulated Fraud Detection Model
class FraudDetectionModel(nn.Module):
    def __init__(self, input_dim=10):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()  # Output: fraud probability [0, 1]
        )
    
    def forward(self, x):
        return self.network(x)

# Instantiate the model with random weights for this demo
# (in production, load trained weights, e.g. from S3/GCS)
model = FraudDetectionModel()
model.eval()  # Inference mode (disables dropout)

print("✅ Fraud detection model loaded")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Simulate Feature Engineering (in production, use Feature Store)
def extract_features(transaction: Dict[str, Any]) -> torch.Tensor:
    """
    Convert raw transaction to ML features.
    In production, fetch from Feature Store (Feast, Tecton, etc.)
    """
    features = [
        transaction['amount'],
        transaction['merchant_risk_score'],
        transaction['user_velocity_1h'],  # Transactions in last hour
        transaction['user_velocity_24h'],
        transaction['amount_vs_avg_ratio'],
        transaction['time_since_last_txn_minutes'],
        transaction['is_foreign'],  # Boolean → 0 or 1
        transaction['device_trust_score'],
        transaction['merchant_category_risk'],
        transaction['hour_of_day'] / 24.0  # Normalize
    ]
    return torch.tensor(features, dtype=torch.float32)

# Test feature extraction
sample_transaction = {
    'transaction_id': 'txn_12345',
    'amount': 450.00,
    'merchant_risk_score': 0.3,
    'user_velocity_1h': 2,
    'user_velocity_24h': 15,
    'amount_vs_avg_ratio': 2.1,
    'time_since_last_txn_minutes': 45,
    'is_foreign': 0,
    'device_trust_score': 0.85,
    'merchant_category_risk': 0.2,
    'hour_of_day': 14
}

features = extract_features(sample_transaction)
print(f"Extracted features: {features}")
print(f"Feature shape: {features.shape}")

In [None]:
# Streaming Inference Loop (Simulated)
class StreamingInferenceEngine:
    def __init__(self, model, fraud_threshold=0.7):
        self.model = model
        self.fraud_threshold = fraud_threshold
        self.stats = {'processed': 0, 'fraudulent': 0, 'avg_latency_ms': 0}
    
    def process_event(self, transaction: Dict) -> Dict:
        """Process a single transaction event."""
        start = time.perf_counter()  # High-resolution clock, suited to timing short intervals
        
        # 1. Extract features
        features = extract_features(transaction).unsqueeze(0)  # Add batch dim
        
        # 2. Run inference
        with torch.no_grad():
            fraud_score = self.model(features).item()
        
        # 3. Make decision
        is_fraud = fraud_score > self.fraud_threshold
        
        latency = (time.perf_counter() - start) * 1000  # Convert to ms
        
        # 4. Update stats
        self.stats['processed'] += 1
        if is_fraud:
            self.stats['fraudulent'] += 1
        self.stats['avg_latency_ms'] = (
            (self.stats['avg_latency_ms'] * (self.stats['processed'] - 1) + latency)
            / self.stats['processed']
        )
        
        return {
            'transaction_id': transaction['transaction_id'],
            'fraud_score': fraud_score,
            'is_fraud': is_fraud,
            'latency_ms': latency
        }
    
    def get_stats(self):
        return self.stats

# Initialize engine
engine = StreamingInferenceEngine(model, fraud_threshold=0.7)

# Simulate streaming events
import random

print("\nSimulating real-time transaction stream:\n")
for i in range(10):
    # Simulate transaction (in production, this comes from Kafka)
    txn = {
        'transaction_id': f'txn_{i+1:05d}',
        'amount': random.uniform(10, 1000),
        'merchant_risk_score': random.random(),
        'user_velocity_1h': random.randint(0, 10),
        'user_velocity_24h': random.randint(0, 50),
        'amount_vs_avg_ratio': random.uniform(0.5, 3.0),
        'time_since_last_txn_minutes': random.randint(1, 120),
        'is_foreign': random.choice([0, 1]),
        'device_trust_score': random.random(),
        'merchant_category_risk': random.random(),
        'hour_of_day': random.randint(0, 23)
    }
    
    result = engine.process_event(txn)
    
    status = "🚨 FRAUD" if result['is_fraud'] else "✅ Clean"
    print(f"{status} | {result['transaction_id']} | Score: {result['fraud_score']:.3f} | {result['latency_ms']:.2f}ms")

print("\n📊 Statistics:")
stats = engine.get_stats()
print(f"Processed: {stats['processed']} transactions")
print(f"Fraudulent: {stats['fraudulent']} ({stats['fraudulent']/stats['processed']*100:.1f}%)")
print(f"Average latency: {stats['avg_latency_ms']:.2f}ms")

## Part 4: Feature Stores - The Missing Piece

### The Problem: Training-Serving Skew

**Scenario**: You train a model using Spark (batch), but serve it in real-time (streaming).

**What goes wrong**:
```python
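# Training (PySpark batch job)
from pyspark.sql.functions import avg
user_avg_txn = df.groupBy('user_id').agg(avg('amount'))  # Distributed aggregate over full history

# Serving (plain Python in the streaming service)
user_avg_txn = sum(amounts) / len(amounts)  # Hand-rolled mean over whatever is in memory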
```

**Result**: Slight differences in float precision, rounding, or logic → **model performs worse in production!**
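
A logic mismatch is often larger than a precision mismatch. The sketch below uses two hypothetical implementations of the `amount_vs_avg_ratio` feature: the batch version averages over history including the current transaction, while the serving version averages past transactions only.

```python
def ratio_training(history, amount):
    """Batch job (hypothetical): the average includes the current transaction."""
    amounts = history + [amount]
    return amount / (sum(amounts) / len(amounts))

def ratio_serving(history, amount):
    """Streaming service (hypothetical): the average uses past transactions only."""
    return amount / (sum(history) / len(history))

history = [100.0, 100.0, 100.0]
amount = 500.0
print(ratio_training(history, amount))  # 2.5
print(ratio_serving(history, amount))   # 5.0
```

The model was trained on values like 2.5 but sees 5.0 at inference time for the same transaction, so its learned thresholds no longer line up with the inputs it receives.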

### The Solution: Feature Store

A Feature Store provides:
1. **Single source of truth** for feature definitions
2. **Offline store** (historical data for training)
3. **Online store** (low-latency lookup for inference)
4. **Automatic sync** between offline and online
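
The four properties above can be illustrated with a toy, in-memory store. This is a sketch under simplifying assumptions, not how Feast or any other product is implemented: `ToyFeatureStore` and its method names are hypothetical, and the "stores" are plain Python containers.

```python
from typing import Callable, Dict, List

class ToyFeatureStore:
    """In-memory illustration: one definition per feature, two stores fed together."""

    def __init__(self):
        self._definitions: Dict[str, Callable[[List[float]], float]] = {}
        self._offline: List[Dict] = []      # historical rows for training
        self._online: Dict[str, Dict] = {}  # latest values keyed by entity id

    def register(self, name: str, fn: Callable[[List[float]], float]) -> None:
        """Single source of truth: each feature is defined exactly once."""
        self._definitions[name] = fn

    def materialize(self, entity_id: str, amounts: List[float]) -> None:
        """Compute every feature once and write the result to both stores."""
        row = {name: fn(amounts) for name, fn in self._definitions.items()}
        self._offline.append({"entity_id": entity_id, **row})
        self._online[entity_id] = row

    def get_training_rows(self) -> List[Dict]:
        return list(self._offline)

    def get_online_features(self, entity_id: str) -> Dict:
        return dict(self._online[entity_id])

store = ToyFeatureStore()
store.register("avg_amount", lambda a: sum(a) / len(a))
store.register("txn_count", lambda a: float(len(a)))
store.materialize("user_123", [20.0, 40.0, 60.0])

print(store.get_online_features("user_123"))  # {'avg_amount': 40.0, 'txn_count': 3.0}
print(store.get_training_rows()[-1])
```

Training reads from the offline rows and serving reads from the online lookup, but both values came from the same registered function, so they cannot drift apart.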

### Popular Feature Stores
- **Feast** (open source, widely used)
- **Tecton** (Enterprise)
- **AWS SageMaker Feature Store**
- **Databricks Feature Store**
- **Hopsworks**

In [None]:
# Conceptual Example: Using Feast Feature Store

feast_example = '''
from feast import FeatureStore

# Initialize Feast from a feature repository
# (feature views are defined once, in the repo)
store = FeatureStore(repo_path=".")

# Entities to look up; online lookups need only the entity key
entity_rows = [
    {"user_id": "user_123"}
]

# Get features for inference (from online store - Redis/DynamoDB)
features = store.get_online_features(
    features=[
        "user_features:avg_transaction_amount",
        "user_features:transaction_count_24h",
        "user_features:fraud_history_score"
    ],
    entity_rows=entity_rows
).to_dict()

# These come from the same feature definitions used to build the training set,
# which is what prevents training-serving skew.
'''

print("Feast Feature Store Usage:")
print(feast_example)
print("\n✅ Key benefit: Training and serving use IDENTICAL feature logic!")

## Part 5: Production Kafka Integration

### Real Kafka Consumer Example

In [None]:
# Production-ready Kafka consumer with PyTorch inference
# (Requires: pip install kafka-python)

kafka_consumer_code = '''
from kafka import KafkaConsumer, KafkaProducer
import json
import time
import torch

# FraudDetectionModel and extract_features are defined in the cells above

# Initialize Kafka consumer
consumer = KafkaConsumer(
    'transactions',  # Input topic
    bootstrap_servers=['localhost:9092'],
    value_deserializer=lambda m: json.loads(m.decode('utf-8')),
    auto_offset_reset='latest',
    enable_auto_commit=True,
    group_id='fraud-detection-service'
)

# Initialize Kafka producer
producer = KafkaProducer(
    bootstrap_servers=['localhost:9092'],
    value_serializer=lambda m: json.dumps(m).encode('utf-8')
)

# Load model once at startup
model = FraudDetectionModel()
model.eval()

print("🚀 Streaming inference service started...")

# Main processing loop
for message in consumer:
    transaction = message.value
    
    # Extract features
    features = extract_features(transaction).unsqueeze(0)
    
    # Inference
    with torch.no_grad():
        fraud_score = model(features).item()
    
    # Publish result to output topic
    result = {
        'transaction_id': transaction['transaction_id'],
        'fraud_score': fraud_score,
        'is_fraud': fraud_score > 0.7,
        'timestamp': time.time()
    }
    
    producer.send('fraud-scores', value=result)
    
    if result['is_fraud']:
        print(f"🚨 Fraud detected: {result['transaction_id']} (score: {fraud_score:.3f})")
'''

print("Production Kafka + PyTorch Integration:")
print(kafka_consumer_code)
print("\n⚡ A loop like this can process millions of transactions per day in real time!")

## Part 6: Monitoring & Observability

### Critical Metrics for Streaming ML

1. **Latency Metrics**
   - P50, P95, P99 latency
   - Time-to-first-byte (TTFB)
   
2. **Throughput**
   - Events/second processed
   - Backlog/lag (messages waiting)
   
3. **Model Quality**
   - Prediction distribution drift
   - Feature distribution drift
   - Online accuracy (when labels arrive)
   
4. **System Health**
   - Consumer lag (Kafka)
   - Error rate
   - Memory/CPU utilization
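
For prediction distribution drift, one widely used summary statistic is the Population Stability Index (PSI). The sketch below is a minimal pure-Python version that assumes scores lie in [0, 1] and uses ten equal-width bins. The bin count, the `eps` floor, and the sample data are illustrative assumptions.

```python
import math

def psi(expected, actual, n_bins=10, eps=1e-6):
    """Population Stability Index between two samples of scores in [0, 1]."""
    if not expected or not actual:
        raise ValueError("both samples must be non-empty")

    def bin_fractions(scores):
        counts = [0] * n_bins
        for s in scores:
            idx = min(int(s * n_bins), n_bins - 1)  # a score of 1.0 falls in the last bin
            counts[idx] += 1
        # Floor empty bins at eps so the logarithm below stays defined
        return [max(c / len(scores), eps) for c in counts]

    e = bin_fractions(expected)
    a = bin_fractions(actual)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))

# Hypothetical score samples: a uniform baseline and a shifted live window
baseline = [i / 100 for i in range(100)]
shifted = [min(s + 0.3, 1.0) for s in baseline]

print(psi(baseline, baseline))  # 0.0
print(psi(baseline, shifted))
```

Identical distributions give 0 and the value grows as the live scores move away from the baseline. A commonly cited rule of thumb treats values above roughly 0.25 as a significant shift, though the right alert threshold depends on the application.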

### Example: Prometheus Metrics

In [None]:
# Example monitoring setup (requires prometheus_client)

monitoring_code = '''
from prometheus_client import Counter, Histogram, Gauge, start_http_server

# Define metrics
transactions_processed = Counter(
    'fraud_detection_transactions_total',
    'Total transactions processed'
)

fraud_detected = Counter(
    'fraud_detection_fraud_total',
    'Total fraudulent transactions detected'
)

inference_latency = Histogram(
    'fraud_detection_latency_seconds',
    'Inference latency in seconds',
    buckets=[0.001, 0.005, 0.01, 0.05, 0.1, 0.5]
)

fraud_score_dist = Histogram(
    'fraud_detection_score_distribution',
    'Distribution of fraud scores',
    buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
)

# Start metrics server
start_http_server(8000)  # Expose metrics at :8000/metrics

# In your inference loop:
with inference_latency.time():
    fraud_score = model(features).item()

transactions_processed.inc()
fraud_score_dist.observe(fraud_score)

if fraud_score > 0.7:
    fraud_detected.inc()
'''

print("Production Monitoring Setup:")
print(monitoring_code)
print("\n📊 Grafana dashboards can visualize these metrics in real-time!")

## Part 7: Best Practices & Anti-Patterns

### ✅ DO
1. **Use feature stores** to avoid training-serving skew
2. **Monitor model drift** - retrain when performance degrades
3. **Set latency SLAs** and alert when violated
4. **Implement circuit breakers** - fallback if model fails
5. **Version your models** - enable rollbacks
6. **Load models once** at startup, not per request
7. **Use micro-batching** (e.g., 10ms window) for throughput
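
Micro-batching from the list above can be sketched as a generator that flushes on either a size limit or a time window. The function name, defaults, and injectable `clock` parameter are assumptions made for illustration. Each yielded batch would then be scored with one forward pass, for example by stacking the feature tensors with `torch.stack`.

```python
import time
from typing import Callable, Iterable, Iterator, List

def micro_batches(
    events: Iterable,
    max_size: int = 32,
    max_wait_s: float = 0.010,
    clock: Callable[[], float] = time.monotonic,
) -> Iterator[List]:
    """Group events into batches, flushing on size or once the window has elapsed."""
    batch: List = []
    deadline = 0.0
    for event in events:
        if not batch:
            deadline = clock() + max_wait_s  # window opens with the first event
        batch.append(event)
        if len(batch) >= max_size or clock() >= deadline:
            yield batch
            batch = []
    if batch:
        yield batch  # flush whatever is left when the stream ends

# Size-based flushing: the 60 s window never expires in this short run
print(list(micro_batches(range(10), max_size=4, max_wait_s=60.0)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

One limitation of this sketch: the deadline is only checked when an event arrives, so a source that blocks indefinitely can hold a partial batch past the window. With a Kafka consumer, polling with a timeout is the usual way to bound that wait.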

### ❌ DON'T
1. **Don't recompute features** differently in training vs serving
2. **Don't ignore backpressure** - handle slow consumers
3. **Don't deploy without monitoring** - you'll be blind
4. **Don't use batch pipelines** for real-time use cases
5. **Don't skip model validation** before deploying
6. **Don't ignore data quality** - garbage in, garbage out

### Handling Edge Cases
- **Missing features**: Have default values or skip prediction
- **Model errors**: Return safe default (e.g., flag as suspicious)
- **Kafka downtime**: Buffer locally, implement retry logic
- **Feature store lag**: Cache last known values (with TTL)
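
The first two edge cases can be handled by a thin wrapper around the model call. The sketch below is one possible shape: the default values, the status strings, and the choice to fail toward "suspicious" are all assumptions that a real system would tune to its own risk tolerance.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical defaults used when an upstream feature is missing
FEATURE_DEFAULTS: Dict[str, float] = {
    "amount": 0.0,
    "merchant_risk_score": 0.5,
    "device_trust_score": 0.5,
}
SAFE_DEFAULT_SCORE = 1.0  # treat the transaction as suspicious when the model fails

def build_features(transaction: Dict, required: List[str]) -> Optional[List[float]]:
    """Fill missing values from defaults; return None if a feature has no default."""
    values = []
    for name in required:
        value = transaction.get(name, FEATURE_DEFAULTS.get(name))
        if value is None:
            return None
        values.append(float(value))
    return values

def safe_score(transaction: Dict, model_fn: Callable[[List[float]], float],
               required: List[str]) -> Dict:
    """Score a transaction without letting bad input or a model error stop the stream."""
    features = build_features(transaction, required)
    if features is None:
        return {"fraud_score": None, "status": "skipped_missing_features"}
    try:
        return {"fraud_score": model_fn(features), "status": "ok"}
    except Exception:
        # Any model failure falls back to the safe default score
        return {"fraud_score": SAFE_DEFAULT_SCORE, "status": "model_error_fallback"}

required = ["amount", "merchant_risk_score"]
print(safe_score({"amount": 10}, lambda f: 0.2, required))
# {'fraud_score': 0.2, 'status': 'ok'}
```

Returning a status alongside the score lets downstream consumers and dashboards distinguish real predictions from fallbacks, which matters when tracking the error rate.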

## Summary: The Streaming ML Stack (2025)

### Architecture Components
```
Data Sources → Kafka → [Streaming App + PyTorch Model] → Kafka → Downstream Systems
                         ↓ (feature lookup)
                    Feature Store (Redis/DynamoDB)
                         ↓ (metrics)
                   Prometheus → Grafana
```

### Technology Stack
- **Messaging**: Apache Kafka (industry standard)
- **Processing**: Flink, Kafka Streams, or custom Python
- **Feature Store**: Feast, Tecton, or cloud-native
- **Monitoring**: Prometheus + Grafana
- **Model Serving**: Embedded or separate service

### Performance Targets
- **Latency**: < 50ms for P95
- **Throughput**: 10,000+ events/second per instance
- **Availability**: 99.9% uptime
- **Model freshness**: < 1 hour lag from training
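
The throughput target translates into a rough instance count with simple arithmetic. The sketch below assumes a peak of one million events per second and reserves 30% headroom per instance. Both numbers are illustrative, and `instances_needed` is a hypothetical helper.

```python
import math

def instances_needed(peak_events_per_s, per_instance_events_per_s, headroom=0.3):
    """Instances required to absorb peak load while leaving spare capacity on each."""
    usable = per_instance_events_per_s * (1.0 - headroom)
    return math.ceil(peak_events_per_s / usable)

print(instances_needed(1_000_000, 10_000))                # 143
print(instances_needed(1_000_000, 10_000, headroom=0.0))  # 100
```

With Kafka, consumers in one group are also capped by the partition count, since each partition is read by at most one consumer in the group. The input topic therefore needs at least as many partitions as the instance count computed here.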

### What FAANG Expects You to Know
- ✅ Difference between batch and streaming inference
- ✅ How to prevent training-serving skew
- ✅ What feature stores are and why they matter
- ✅ Kafka integration patterns
- ✅ How to monitor streaming ML systems
- ✅ When to use embedded vs enricher pattern
- ✅ Latency optimization techniques

### Further Reading
- [Kafka Documentation](https://kafka.apache.org/documentation/)
- [Feast Feature Store](https://feast.dev/)
- [Real-time ML with Kafka and Flink](https://www.kai-waehner.de/blog/2024/10/01/real-time-model-inference-with-apache-kafka-and-flink-for-predictive-ai-and-genai/)

**You now understand production streaming ML infrastructure! 🚀**