# ML System Design Patterns

## Overview
Core ML system design patterns tested in FAANG interviews:
- **Recommendation Systems**: Two-tower, collaborative filtering, content-based
- **Search Ranking**: Learning-to-rank, query understanding, relevance
- **Fraud Detection**: Real-time scoring, anomaly detection, imbalanced data
- **Ads/CTR Prediction**: Click models, calibration, explore-exploit
- **Content Moderation**: Multi-label classification, active learning

## FAANG Interview Framework
1. **Clarify** (10%): Requirements, scale, constraints
2. **High-Level Design** (20%): Components, data flow
3. **Deep Dive** (50%): Model architecture, features, training
4. **Evaluation** (20%): Metrics, A/B testing, monitoring

## Common Questions
- Design Instagram's Explore page (Meta)
- Design YouTube search ranking (Google)
- Design fraud detection for Stripe (Stripe/TikTok)
- Design ad click prediction (Meta/Google)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import time

print("ML System Design Patterns - FAANG Interview Prep")

---
# Part 1: Recommendation Systems

## The Two-Tower Architecture
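A standard pattern for candidate retrieval in large-scale recommenders (YouTube, Instagram, TikTok). Separate towers encode user and item features, and relevance is the dot product of the two embeddings.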

```
User Tower          Item Tower
    |                   |
User Features      Item Features
    |                   |
  [MLP]               [MLP]
    |                   |
User Embedding     Item Embedding
    \                  /
     \                /
      Dot Product Score
```

In [None]:
class TwoTowerModel(nn.Module):
    """
    Two-Tower architecture for candidate generation.
    
    Used by: YouTube, Instagram Reels, TikTok
    Scale: Millions of items, billions of users
    Latency: ~10ms for retrieval
    """
    
    def __init__(self, 
                 num_users: int,
                 num_items: int,
                 embedding_dim: int = 64,
                 hidden_dims: List[int] = [128, 64]):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        
        # User tower
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.user_mlp = self._build_mlp(
            embedding_dim, hidden_dims, embedding_dim
        )
        
        # Item tower
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.item_mlp = self._build_mlp(
            embedding_dim, hidden_dims, embedding_dim
        )
        
        # Temperature for softmax
        self.temperature = nn.Parameter(torch.ones(1))
    
    def _build_mlp(self, input_dim: int, hidden_dims: List[int], 
                   output_dim: int) -> nn.Sequential:
        """Build MLP tower."""
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        return nn.Sequential(*layers)
    
    def get_user_embedding(self, user_ids: torch.Tensor) -> torch.Tensor:
        """Get user embeddings (for serving)."""
        user_emb = self.user_embedding(user_ids)
        user_vec = self.user_mlp(user_emb)
        return F.normalize(user_vec, p=2, dim=-1)
    
    def get_item_embedding(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Get item embeddings (pre-computed for ANN index)."""
        item_emb = self.item_embedding(item_ids)
        item_vec = self.item_mlp(item_emb)
        return F.normalize(item_vec, p=2, dim=-1)
    
    def forward(self, user_ids: torch.Tensor, 
                positive_ids: torch.Tensor,
                negative_ids: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass with in-batch negatives.
        
        Training strategy: Use other items in batch as negatives.
        negative_ids is currently unused; explicit negatives are not implemented.
        """
        user_vec = self.get_user_embedding(user_ids)
        pos_vec = self.get_item_embedding(positive_ids)
        
        # Positive scores
        pos_scores = torch.sum(user_vec * pos_vec, dim=-1)
        
        # In-batch negatives: all items in batch are potential negatives
        # Shape: (batch_size, batch_size)
        all_scores = torch.matmul(user_vec, pos_vec.T) / self.temperature
        
        # Labels: diagonal is positive (index i matches user i)
        labels = torch.arange(len(user_ids), device=user_ids.device)
        
        # Cross-entropy loss (treats as multi-class classification)
        loss = F.cross_entropy(all_scores, labels)
        
        return {
            'loss': loss,
            'pos_scores': pos_scores,
            'user_embeddings': user_vec,
            'item_embeddings': pos_vec
        }

# Example
print("\n=== Two-Tower Recommendation Model ===")
model = TwoTowerModel(num_users=10000, num_items=100000)

# Training batch
user_ids = torch.randint(0, 10000, (32,))
item_ids = torch.randint(0, 100000, (32,))

output = model(user_ids, item_ids)
print(f"Loss: {output['loss'].item():.4f}")
print(f"User embedding shape: {output['user_embeddings'].shape}")
print(f"Item embedding shape: {output['item_embeddings'].shape}")
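
The in-batch negative loss in `forward` is a row-wise softmax cross-entropy over the user-by-item score matrix, with the diagonal as the positive class. The sketch below recomputes that quantity in plain Python so the arithmetic is visible; `in_batch_softmax_loss` and the toy score matrices are illustrative names and values, not part of the model above.

```python
import math
from typing import List

def in_batch_softmax_loss(scores: List[List[float]]) -> float:
    """Mean cross-entropy where row i's positive item is column i."""
    total = 0.0
    for i, row in enumerate(scores):
        m = max(row)  # subtract the row max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total += log_z - row[i]
    return total / len(scores)

# All scores equal (an uninformative model): loss equals log(batch_size).
uniform = [[0.0] * 4 for _ in range(4)]
uniform_loss = in_batch_softmax_loss(uniform)

# Positives score far above negatives: loss approaches 0.
separated = [[10.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
separated_loss = in_batch_softmax_loss(separated)

# Best case for unit-norm embeddings at temperature 1 with batch size 32:
# positives score +1, negatives score -1.
bounded = [[1.0 if i == j else -1.0 for j in range(32)] for i in range(32)]
bounded_loss = in_batch_softmax_loss(bounded)

print(uniform_loss, separated_loss, bounded_loss)
```

With L2-normalized embeddings and a temperature of 1, scores lie in [-1, 1], so a batch of 32 cannot push this loss below log(1 + 31·e^-2), about 1.65. That floor is one reason a temperature below 1 is often used with normalized towers.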

In [None]:
class CandidateGeneration:
    """
    Stage 1: Retrieve candidates from millions of items.
    
    Techniques:
    - ANN search (FAISS, ScaNN)
    - Collaborative filtering
    - Content-based filtering
    """
    
    def __init__(self, embedding_dim: int = 64):
        self.embedding_dim = embedding_dim
        self.item_index = None  # FAISS index in production
        self.item_embeddings = None
    
    def build_index(self, item_embeddings: np.ndarray):
        """Build ANN index for fast retrieval."""
        # In production: Use FAISS or ScaNN
        # faiss.IndexFlatIP or faiss.IndexIVFFlat
        self.item_embeddings = item_embeddings
        print(f"Built index with {len(item_embeddings)} items")
    
    def retrieve(self, user_embedding: np.ndarray, k: int = 100) -> List[Tuple[int, float]]:
        """
        Retrieve top-k candidates.
        
        This demo scores every item exactly (brute force); production
        systems replace this with an ANN index.
        Latency target: < 10ms
        """
        if self.item_embeddings is None:
            return []
        
        # Cosine similarity (dot product for normalized vectors)
        scores = np.dot(self.item_embeddings, user_embedding)
        
        # Top-k
        top_indices = np.argsort(scores)[-k:][::-1]
        
        return [(idx, scores[idx]) for idx in top_indices]

class Ranker(nn.Module):
    """
    Stage 2: Rank candidates with a heavy model.
    
    Features:
    - User features (demographics, history)
    - Item features (content, popularity)
    - Context features (time, device)
    - Cross features (user-item interactions)
    """
    
    def __init__(self, feature_dim: int, hidden_dims: List[int] = [256, 128, 64]):
        super().__init__()
        
        layers = []
        prev_dim = feature_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim
        
        # Multi-task heads
        self.shared = nn.Sequential(*layers)
        self.click_head = nn.Linear(prev_dim, 1)  # P(click)
        self.engage_head = nn.Linear(prev_dim, 1)  # P(engagement)
        self.share_head = nn.Linear(prev_dim, 1)  # P(share)
    
    def forward(self, features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Predict multiple objectives."""
        shared_rep = self.shared(features)
        
        return {
            'p_click': torch.sigmoid(self.click_head(shared_rep)),
            'p_engage': torch.sigmoid(self.engage_head(shared_rep)),
            'p_share': torch.sigmoid(self.share_head(shared_rep))
        }
    
    def compute_ranking_score(self, predictions: Dict[str, torch.Tensor],
                              weights: Dict[str, float] = None) -> torch.Tensor:
        """
        Combine predictions into final ranking score.
        
        Score = w1 * P(click) + w2 * P(engage) + w3 * P(share)
        """
        if weights is None:
            weights = {'p_click': 1.0, 'p_engage': 2.0, 'p_share': 3.0}
        
        score = sum(weights[k] * v for k, v in predictions.items())
        return score

print("\n=== Candidate Generation + Ranking ===")
print("Stage 1 (Retrieval): 1M items -> 100 candidates (<10ms)")
print("Stage 2 (Ranking): 100 candidates -> 20 ranked (<50ms)")
print("Stage 3 (Re-ranking): Business rules, diversity (<10ms)")
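
Stage 3 is listed above but not implemented. One simple way to inject diversity is a greedy pass that caps how many items from the same category may appear in the final slate. This is a minimal sketch under that assumption; `rerank_with_category_cap`, the category labels, and the cap value are hypothetical, and real systems often use richer rules (for example MMR-style scoring).

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def rerank_with_category_cap(
    ranked: List[Tuple[str, float]],
    item_category: Dict[str, str],
    max_per_category: int = 2,
    top_n: int = 5,
) -> List[str]:
    """Walk items in score order, skipping categories that hit the cap."""
    counts: Dict[str, int] = defaultdict(int)
    selected: List[str] = []
    for item_id, _score in sorted(ranked, key=lambda x: x[1], reverse=True):
        category = item_category[item_id]
        if counts[category] >= max_per_category:
            continue  # business rule: too many items from this category
        counts[category] += 1
        selected.append(item_id)
        if len(selected) == top_n:
            break
    return selected

ranked = [("a", 0.9), ("b", 0.8), ("c", 0.7), ("d", 0.6), ("e", 0.5)]
item_category = {"a": "sports", "b": "sports", "c": "sports",
                 "d": "music", "e": "news"}
reranked = rerank_with_category_cap(ranked, item_category,
                                    max_per_category=2, top_n=4)
print(reranked)
```

Item `c` has a higher score than `d` and `e` but is dropped because two sports items are already selected.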

---
# Part 2: Search Ranking (Learning-to-Rank)

## The Search Pipeline
```
Query -> Query Understanding -> Retrieval -> Ranking -> Re-ranking -> Results
```

In [None]:
class QueryUnderstanding:
    """
    Parse and enrich user queries.
    
    Components:
    - Query classification (navigational, informational, transactional)
    - Entity recognition (brands, categories)
    - Query expansion (synonyms, spelling correction)
    - Intent detection
    """
    
    def __init__(self):
        self.stop_words = {'the', 'a', 'an', 'is', 'are', 'for', 'to'}
    
    def parse(self, query: str) -> Dict[str, Any]:
        """Parse and enrich query."""
        tokens = query.lower().split()
        
        return {
            'original': query,
            'tokens': tokens,
            'filtered_tokens': [t for t in tokens if t not in self.stop_words],
            'intent': self._detect_intent(query),
            'entities': self._extract_entities(query),
            'query_type': self._classify_query(query)
        }
    
    def _detect_intent(self, query: str) -> str:
        """Detect user intent from whole-word keyword matches."""
        # Match on tokens, not substrings, so "show" does not trigger "how"
        tokens = set(query.lower().split())
        
        if tokens & {'buy', 'price', 'cheap', 'deal'}:
            return 'transactional'
        elif tokens & {'how', 'what', 'why', 'tutorial'}:
            return 'informational'
        else:
            return 'navigational'
    
    def _extract_entities(self, query: str) -> List[Dict]:
        """Extract named entities."""
        # In production: Use NER model
        entities = []
        
        # Simple brand detection (whole-token match, so "pineapple" is not "apple")
        brands = ['apple', 'samsung', 'nike', 'sony']
        tokens = set(query.lower().split())
        for brand in brands:
            if brand in tokens:
                entities.append({'type': 'brand', 'value': brand})
        
        return entities
    
    def _classify_query(self, query: str) -> str:
        """Classify query type, using length as a rough proxy for frequency."""
        if len(query.split()) == 1:
            return 'head'  # Popular, short
        elif len(query.split()) <= 3:
            return 'torso'  # Medium frequency
        else:
            return 'tail'  # Long, specific

class LearningToRank(nn.Module):
    """
    Learning-to-Rank model for search.
    
    Approaches:
    - Pointwise: Predict relevance score
    - Pairwise: Predict which doc is more relevant (RankNet)
    - Listwise: Optimize list-level metric (LambdaRank, ListNet)
    """
    
    def __init__(self, feature_dim: int):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
    
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Predict relevance score."""
        return self.model(features).squeeze(-1)
    
    def pairwise_loss(self, scores_i: torch.Tensor, scores_j: torch.Tensor,
                      labels: torch.Tensor) -> torch.Tensor:
        """
        RankNet pairwise loss.
        
        labels: 1 if doc_i > doc_j, 0 otherwise
        """
        diff = scores_i - scores_j
        return F.binary_cross_entropy_with_logits(diff, labels.float())
    
    def listwise_loss(self, scores: torch.Tensor, 
                      relevance: torch.Tensor) -> torch.Tensor:
        """
        ListNet loss (softmax cross-entropy).
        
        scores: (batch_size, num_docs)
        relevance: (batch_size, num_docs) - ground truth relevance
        """
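        # Convert to probability distributions (ListNet top-one probabilities).
        # log_softmax is numerically safer than log(softmax(x) + eps).
        log_pred_dist = F.log_softmax(scores, dim=-1)
        true_dist = F.softmax(relevance.float(), dim=-1)
        
        # Cross-entropy between target and predicted distributions
        loss = -torch.sum(true_dist * log_pred_dist, dim=-1)
        return loss.mean()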

def compute_ndcg(relevance: List[int], k: int = 10) -> float:
    """
    Normalized Discounted Cumulative Gain.
    
    Primary metric for search ranking.
    """
    def dcg(rel: List[int], k: int) -> float:
        rel = rel[:k]
        return sum((2**r - 1) / np.log2(i + 2) for i, r in enumerate(rel))
    
    # DCG of predicted ranking
    actual_dcg = dcg(relevance, k)
    
    # Ideal DCG (sorted by relevance)
    ideal_dcg = dcg(sorted(relevance, reverse=True), k)
    
    if ideal_dcg == 0:
        return 0.0
    
    return actual_dcg / ideal_dcg

# Example
print("\n=== Search Ranking Example ===")
qu = QueryUnderstanding()
parsed = qu.parse("buy cheap apple iphone")
print(f"Query: {parsed['original']}")
print(f"Intent: {parsed['intent']}")
print(f"Entities: {parsed['entities']}")
print(f"Type: {parsed['query_type']}")

# NDCG example
relevance = [3, 2, 3, 0, 1, 2]  # Graded relevance (0-3)
print(f"\nNDCG@5: {compute_ndcg(relevance, k=5):.4f}")
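
`pairwise_loss` expects score pairs plus a 0/1 label, but the cell above does not show where those pairs come from. A minimal sketch, assuming graded relevance labels for the documents of a single query: enumerate document pairs and keep only those whose grades differ. `make_ranknet_pairs` is a hypothetical helper name.

```python
from itertools import combinations
from typing import List, Tuple

def make_ranknet_pairs(relevance: List[int]) -> List[Tuple[int, int, int]]:
    """Return (i, j, label) per pair; label is 1 if doc i is more relevant."""
    pairs = []
    for i, j in combinations(range(len(relevance)), 2):
        if relevance[i] == relevance[j]:
            continue  # ties carry no preference signal
        pairs.append((i, j, 1 if relevance[i] > relevance[j] else 0))
    return pairs

# Four documents for one query, graded 0-3
pairs = make_ranknet_pairs([3, 2, 3, 0])
for i, j, label in pairs:
    print(f"doc {i} vs doc {j}: label={label}")
```

The indices select rows of the model's score vector, giving `scores_i`, `scores_j`, and `labels` for the loss. Pairs should only be formed within one query, never across queries.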

---
# Part 3: Fraud Detection

## Real-Time Fraud Scoring System
Requirements:
- Latency: < 100ms per transaction
- Scale: Millions of transactions/day
- Imbalanced: ~0.1% fraud rate

In [None]:
@dataclass
class Transaction:
    """Transaction data for fraud detection."""
    transaction_id: str
    user_id: str
    amount: float
    merchant_id: str
    timestamp: float
    device_id: str
    ip_address: str
    location: Tuple[float, float]  # lat, lon
    card_type: str
    is_international: bool

class FraudFeatureEngine:
    """
    Feature engineering for fraud detection.
    
    Feature categories:
    1. Transaction features (amount, time, location)
    2. User behavior features (velocity, patterns)
    3. Device/network features (device fingerprint, IP)
    4. Merchant features (risk score, category)
    5. Graph features (connections, clusters)
    """
    
    def __init__(self):
        self.user_history = defaultdict(list)
        self.device_history = defaultdict(list)
    
    def extract_features(self, txn: Transaction) -> Dict[str, float]:
        """Extract all features for a transaction."""
        features = {}
        
        # Transaction features
        features['amount'] = txn.amount
        features['amount_log'] = np.log1p(txn.amount)
        # Hour of day in UTC (timestamp is epoch seconds); local time would need the user's timezone
        features['hour_of_day'] = (txn.timestamp % 86400) / 3600
        features['is_night'] = 1 if features['hour_of_day'] < 6 or features['hour_of_day'] > 22 else 0
        features['is_international'] = float(txn.is_international)
        
        # User velocity features
        user_txns = self.user_history[txn.user_id]
        features['txn_count_1h'] = self._count_recent(user_txns, txn.timestamp, 3600)
        features['txn_count_24h'] = self._count_recent(user_txns, txn.timestamp, 86400)
        features['amount_sum_24h'] = self._sum_recent_amounts(user_txns, txn.timestamp, 86400)
        
        # Deviation from user's normal behavior
        if user_txns:
            avg_amount = np.mean([t['amount'] for t in user_txns[-100:]])
            features['amount_deviation'] = (txn.amount - avg_amount) / (avg_amount + 1)
        else:
            features['amount_deviation'] = 0
        
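        # Device features
        device_txns = self.device_history[txn.device_id]
        features['device_user_count'] = len(set(t['user_id'] for t in device_txns))
        # Test the history list itself: indexing the defaultdict above already
        # created the key, so a membership check would always report "seen".
        features['new_device'] = 1 if not device_txns else 0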
        
        # Update history
        self._update_history(txn)
        
        return features
    
    def _count_recent(self, history: List[Dict], timestamp: float, 
                      window_seconds: float) -> int:
        """Count transactions in time window."""
        cutoff = timestamp - window_seconds
        return sum(1 for t in history if t['timestamp'] > cutoff)
    
    def _sum_recent_amounts(self, history: List[Dict], timestamp: float,
                           window_seconds: float) -> float:
        """Sum transaction amounts in time window."""
        cutoff = timestamp - window_seconds
        return sum(t['amount'] for t in history if t['timestamp'] > cutoff)
    
    def _update_history(self, txn: Transaction):
        """Update user and device history."""
        record = {
            'timestamp': txn.timestamp,
            'amount': txn.amount,
            'user_id': txn.user_id
        }
        self.user_history[txn.user_id].append(record)
        self.device_history[txn.device_id].append(record)

class FraudDetector(nn.Module):
    """
    Neural fraud detection model.
    
    Handles imbalanced data with:
    - Focal loss
    - Class weights
    - SMOTE/oversampling
    """
    
    def __init__(self, feature_dim: int, hidden_dims: List[int] = [128, 64, 32]):
        super().__init__()
        
        layers = []
        prev_dim = feature_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim
        
        self.network = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_dim, 1)
    
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Predict fraud probability."""
        hidden = self.network(features)
        logits = self.classifier(hidden)
        return torch.sigmoid(logits).squeeze(-1)
    
    def focal_loss(self, predictions: torch.Tensor, targets: torch.Tensor,
                   gamma: float = 2.0, alpha: float = 0.75) -> torch.Tensor:
        """
        Focal Loss for imbalanced classification.
        
        Down-weights easy examples, focuses on hard ones.
        """
        bce = F.binary_cross_entropy(predictions, targets.float(), reduction='none')
        
        # Focal weight
        pt = torch.where(targets == 1, predictions, 1 - predictions)
        focal_weight = (1 - pt) ** gamma
        
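        # Class weight (alpha for positive class, 1 - alpha for negative).
        # Written as arithmetic so it does not rely on scalar support in torch.where.
        targets_f = targets.float()
        class_weight = alpha * targets_f + (1 - alpha) * (1 - targets_f)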
        
        loss = focal_weight * class_weight * bce
        return loss.mean()

def compute_fraud_metrics(y_true: np.ndarray, y_pred: np.ndarray,
                          threshold: float = 0.5) -> Dict[str, float]:
    """
    Compute fraud detection metrics.
    
    Key metrics:
    - Precision: Of predicted frauds, how many are real?
    - Recall: Of real frauds, how many did we catch?
    - False Positive Rate: Good transactions flagged as fraud
    """
    y_binary = (y_pred >= threshold).astype(int)
    
    tp = np.sum((y_binary == 1) & (y_true == 1))
    fp = np.sum((y_binary == 1) & (y_true == 0))
    fn = np.sum((y_binary == 0) & (y_true == 1))
    tn = np.sum((y_binary == 0) & (y_true == 0))
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
    
    # These are threshold-dependent metrics. For a threshold-free summary on
    # imbalanced data, prefer PR-AUC over ROC-AUC (not computed here).
    
    return {
        'precision': precision,
        'recall': recall,
        'f1': 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0,
        'false_positive_rate': fpr,
        'fraud_caught': recall,
        'good_txn_blocked': fpr
    }

# Example
print("\n=== Fraud Detection System ===")
feature_engine = FraudFeatureEngine()

txn = Transaction(
    transaction_id="txn_001",
    user_id="user_123",
    amount=999.99,
    merchant_id="merchant_456",
    timestamp=time.time(),
    device_id="device_789",
    ip_address="192.168.1.1",
    location=(37.7749, -122.4194),
    card_type="visa",
    is_international=False
)

features = feature_engine.extract_features(txn)
print(f"Extracted {len(features)} features")
print(f"Sample features: amount={features['amount']}, amount_deviation={features['amount_deviation']:.2f}")
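
`compute_fraud_metrics` recommends PR-AUC for imbalanced data but does not compute it. Average precision is a common step-wise summary of the precision-recall curve: the mean of the precision values at the rank of each true fraud. The sketch below is a plain-Python version with made-up scores; `average_precision` is an illustrative name, and ties between scores are not handled specially.

```python
from typing import List

def average_precision(y_true: List[int], y_score: List[float]) -> float:
    """Mean of precision@rank taken at the rank of every positive example."""
    order = sorted(range(len(y_score)), key=lambda i: y_score[i], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, idx in enumerate(order, start=1):
        if y_true[idx] == 1:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0

# 2 frauds among 8 transactions (toy data)
y_true = [0, 0, 1, 0, 1, 0, 0, 0]
y_score = [0.1, 0.4, 0.9, 0.8, 0.3, 0.2, 0.05, 0.6]
ap = average_precision(y_true, y_score)
print(f"Average precision: {ap:.2f}")
```

Here the frauds land at ranks 1 and 5, so the value is (1/1 + 2/5) / 2 = 0.70. A random scorer would sit near the fraud base rate, which is why this metric is more informative than ROC-AUC when positives are rare.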

---
# Part 4: Ads / CTR Prediction

## Click-Through Rate Prediction
Core problem for Meta, Google, TikTok ads systems.

In [None]:
class DeepFM(nn.Module):
    """
    DeepFM: Deep Factorization Machine for CTR prediction.
    
    Combines:
    - FM (Factorization Machine): Captures 2nd-order feature interactions
    - DNN: Captures high-order feature interactions
    """
    
    def __init__(self, 
                 num_fields: int,
                 embedding_dim: int = 8,
                 hidden_dims: List[int] = [256, 128, 64]):
        super().__init__()
        
        self.num_fields = num_fields
        self.embedding_dim = embedding_dim
        
        # Embeddings for each field
        self.embeddings = nn.ModuleList([
            nn.Embedding(10000, embedding_dim)  # Assume max 10k values per field
            for _ in range(num_fields)
        ])
        
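        # Linear part (1st order). Simplification: the original DeepFM uses one
        # scalar weight per feature value; here a linear layer over the
        # concatenated embeddings stands in for it.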
        self.linear = nn.Linear(num_fields * embedding_dim, 1)
        
        # DNN part (high-order)
        dnn_input_dim = num_fields * embedding_dim
        layers = []
        prev_dim = dnn_input_dim
        
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, 1))
        self.dnn = nn.Sequential(*layers)
    
    def forward(self, field_indices: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        field_indices: (batch_size, num_fields) - indices for each field
        """
        # Get embeddings for each field
        embedded = []
        for i, emb in enumerate(self.embeddings):
            embedded.append(emb(field_indices[:, i]))
        
        # Stack: (batch_size, num_fields, embedding_dim)
        embedded = torch.stack(embedded, dim=1)
        
        # FM: 2nd order interactions via the identity
        # sum_{i<j} <v_i, v_j> = 0.5 * (||sum_i v_i||^2 - sum_i ||v_i||^2)
        sum_embed = torch.sum(embedded, dim=1)
        square_of_sum = torch.sum(sum_embed ** 2, dim=-1, keepdim=True)
        sum_of_square = torch.sum(torch.sum(embedded ** 2, dim=-1), dim=-1, keepdim=True)
        fm_out = 0.5 * (square_of_sum - sum_of_square)
        
        # Flatten for linear and DNN
        flat = embedded.view(embedded.size(0), -1)
        
        # Linear (1st order)
        linear_out = self.linear(flat)
        
        # DNN (high-order)
        dnn_out = self.dnn(flat)
        
        # Combine
        logits = linear_out + fm_out + dnn_out
        return torch.sigmoid(logits).squeeze(-1)

class CTRCalibration:
    """
    Calibrate CTR predictions for accurate bidding.
    
    Predicted CTR should match actual CTR.
    """
    
    def __init__(self, num_bins: int = 10):
        self.num_bins = num_bins
        self.calibration_map = {}
    
    def _bin_ids(self, predictions: np.ndarray) -> np.ndarray:
        """Map predictions in [0, 1] to bin indices (1.0 falls in the last bin)."""
        ids = np.floor(predictions * self.num_bins).astype(int)
        return np.clip(ids, 0, self.num_bins - 1)
    
    def fit(self, predictions: np.ndarray, labels: np.ndarray):
        """Fit per-bin correction ratios (histogram binning).
        
        Platt scaling and isotonic regression are common alternatives.
        """
        bin_ids = self._bin_ids(predictions)
        
        for i in range(self.num_bins):
            mask = bin_ids == i
            if np.sum(mask) > 0:
                predicted_ctr = predictions[mask].mean()
                actual_ctr = labels[mask].mean()
                self.calibration_map[i] = actual_ctr / (predicted_ctr + 1e-10)
    
    def calibrate(self, predictions: np.ndarray) -> np.ndarray:
        """Apply the fitted per-bin ratios."""
        bin_ids = self._bin_ids(predictions)
        calibrated = predictions.astype(float)  # astype returns a copy
        
        for i, ratio in self.calibration_map.items():
            calibrated[bin_ids == i] *= ratio
        
        return np.clip(calibrated, 0, 1)

print("\n=== CTR Prediction (DeepFM) ===")
model = DeepFM(num_fields=10, embedding_dim=8)

# Sample batch
batch = torch.randint(0, 1000, (32, 10))  # 10 categorical fields
ctr_pred = model(batch)
print(f"Predicted CTR: mean={ctr_pred.mean().item():.4f}, std={ctr_pred.std().item():.4f}")
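
The overview lists explore-exploit under ads, but the cells above cover only click modeling and calibration. One widely used approach is Thompson sampling: keep a Beta posterior over each ad's CTR, sample from every posterior, and show the ad with the highest sample. The sketch below simulates this with made-up true CTRs; `thompson_sampling` and its parameters are hypothetical names, and a uniform Beta(1, 1) prior is assumed.

```python
import random
from typing import List

def thompson_sampling(true_ctrs: List[float], rounds: int, seed: int = 0) -> List[int]:
    """Simulate Bernoulli Thompson sampling; return impressions per ad."""
    rng = random.Random(seed)
    clicks = [0] * len(true_ctrs)
    skips = [0] * len(true_ctrs)
    pulls = [0] * len(true_ctrs)
    for _ in range(rounds):
        # Draw one plausible CTR per ad from its Beta(1 + clicks, 1 + skips) posterior
        samples = [rng.betavariate(1 + c, 1 + s) for c, s in zip(clicks, skips)]
        arm = max(range(len(samples)), key=samples.__getitem__)
        pulls[arm] += 1
        # Simulated user feedback (in production this is the logged click)
        if rng.random() < true_ctrs[arm]:
            clicks[arm] += 1
        else:
            skips[arm] += 1
    return pulls

pulls = thompson_sampling([0.02, 0.05, 0.15], rounds=3000)
print(f"Impressions per ad: {pulls}")
```

Ads with few impressions have wide posteriors and still get sampled occasionally (exploration), while traffic concentrates on the ad with the highest observed CTR (exploitation).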

---
# Part 5: System Design Interview Framework

## Template for ML System Design Questions

In [None]:
class MLSystemDesignTemplate:
    """
    Framework for answering ML system design questions.
    
    Time allocation (45 min interview; percentages are rounded targets):
    - Clarification: 5 min (10%)
    - High-level design: 10 min (20%)
    - Deep dive: 20 min (50%)
    - Evaluation: 10 min (20%)
    """
    
    @staticmethod
    def step1_clarify() -> Dict[str, List[str]]:
        """Questions to ask the interviewer."""
        return {
            'business': [
                'What is the primary business objective?',
                'What metrics matter most? (engagement, revenue, safety)',
                'Are there any constraints? (latency, cost, fairness)'
            ],
            'scale': [
                'How many users/items/requests?',
                'What is the expected QPS?',
                'Real-time or batch predictions?'
            ],
            'data': [
                'What data is available?',
                'How much labeled data do we have?',
                'What is the label distribution? (imbalanced?)'
            ]
        }
    
    @staticmethod
    def step2_high_level_design() -> Dict[str, str]:
        """Components of the system."""
        return {
            'data_pipeline': 'Data collection -> Processing -> Feature store',
            'training_pipeline': 'Feature engineering -> Model training -> Evaluation',
            'serving_pipeline': 'Feature retrieval -> Inference -> Post-processing',
            'monitoring': 'Metrics collection -> Drift detection -> Alerting'
        }
    
    @staticmethod
    def step3_deep_dive() -> Dict[str, List[str]]:
        """Deep dive topics."""
        return {
            'features': [
                'User features (demographics, history, preferences)',
                'Item features (content, popularity, recency)',
                'Context features (time, device, location)',
                'Cross features (user-item interactions)'
            ],
            'model': [
                'Model architecture (two-tower, transformer, GNN)',
                'Loss function (cross-entropy, focal, pairwise)',
                'Multi-task learning (multiple objectives)',
                'Training strategy (in-batch negatives, hard negatives)'
            ],
            'serving': [
                'Candidate generation (ANN, collaborative filtering)',
                'Ranking (heavy model on top-k candidates)',
                'Re-ranking (diversity, business rules)',
                'Caching (user/item embeddings, frequent predictions)'
            ]
        }
    
    @staticmethod
    def step4_evaluation() -> Dict[str, List[str]]:
        """Evaluation metrics and testing."""
        return {
            'offline_metrics': [
                'AUC-ROC, PR-AUC (classification)',
                'NDCG, MRR (ranking)',
                'Precision@k, Recall@k (retrieval)',
                'Calibration (predicted vs actual CTR)'
            ],
            'online_testing': [
                'A/B testing with statistical significance',
                'Interleaving experiments for ranking',
                'Holdout groups for long-term effects',
                'Guardrail metrics (safety, latency)'
            ],
            'monitoring': [
                'Feature drift (PSI, KS test)',
                'Prediction drift (distribution shift)',
                'Performance degradation (accuracy over time)',
                'Latency (P50, P95, P99)'
            ]
        }

# Print the framework
print("\n=== ML System Design Interview Framework ===")
template = MLSystemDesignTemplate()

print("\n1. CLARIFY (5 min)")
for category, questions in template.step1_clarify().items():
    print(f"  {category.upper()}:")
    for q in questions[:2]:
        print(f"    - {q}")

print("\n2. HIGH-LEVEL DESIGN (10 min)")
for component, description in template.step2_high_level_design().items():
    print(f"  {component}: {description}")

print("\n3. DEEP DIVE (20 min)")
for topic, points in template.step3_deep_dive().items():
    print(f"  {topic.upper()}: {points[0]}...")

print("\n4. EVALUATION (10 min)")
for category, metrics in template.step4_evaluation().items():
    print(f"  {category}: {metrics[0]}")
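
The monitoring list in the template names PSI for feature drift without showing how it is computed. A minimal sketch, assuming equal-width bins fixed on the baseline sample: PSI sums `(actual - expected) * ln(actual / expected)` over the bin fractions. `population_stability_index`, the bin count, and the `eps` smoothing constant are illustrative choices.

```python
import math
from typing import List

def population_stability_index(expected: List[float], actual: List[float],
                               num_bins: int = 10, eps: float = 1e-6) -> float:
    """PSI between a baseline sample and a new sample of one feature."""
    lo, hi = min(expected), max(expected)
    width = (hi - lo) / num_bins
    if width == 0:
        width = 1.0  # degenerate baseline with a single distinct value

    def bin_fractions(values: List[float]) -> List[float]:
        counts = [0] * num_bins
        for v in values:
            idx = int((v - lo) / width)
            counts[min(max(idx, 0), num_bins - 1)] += 1  # clamp into edge bins
        return [c / len(values) for c in counts]

    e_frac = bin_fractions(expected)
    a_frac = bin_fractions(actual)
    # eps keeps the log finite when a bin is empty on one side
    return sum((a - e) * math.log((a + eps) / (e + eps))
               for e, a in zip(e_frac, a_frac))

baseline = [i / 100 for i in range(100)]   # training-time feature values
shifted = [v + 0.3 for v in baseline]      # same feature after a shift
psi_same = population_stability_index(baseline, baseline)
psi_shifted = population_stability_index(baseline, shifted)
print(f"PSI (no drift): {psi_same:.4f}, PSI (shifted): {psi_shifted:.4f}")
```

A commonly quoted rule of thumb treats values above roughly 0.25 as a significant shift, though the threshold should be tuned per feature.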

## Key Takeaways

### Recommendation Systems:
- Two-tower for candidate generation (embeddings + ANN)
- Multi-task ranking for final scoring
- In-batch negatives for efficient training

### Search Ranking:
- Query understanding is critical
- Learning-to-rank (pointwise, pairwise, listwise)
- NDCG is the primary offline metric

### Fraud Detection:
- Real-time feature engineering (velocity, device)
- Handle imbalanced data (focal loss, class weights)
- Evaluate on the precision-recall trade-off, not accuracy

### CTR Prediction:
- Feature interactions matter (FM, DeepFM)
- Calibration for accurate bidding
- Multi-task learning for multiple objectives

## FAANG Interview Questions

**Q1: Design Instagram's Explore page.**
- Two-tower for candidate generation (user-post embeddings)
- Multi-task ranker (P(like), P(comment), P(share))
- Diversity injection (different content types)
- Cold-start for new users (popularity, demographics)

**Q2: Design YouTube search ranking.**
- Query understanding (intent, entities)
- Two-stage retrieval (inverted index + semantic)
- Learning-to-rank with NDCG optimization
- Personalization (watch history, preferences)
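
The lexical half of the two-stage retrieval in Q2 can be sketched with a toy inverted index: map each token to the set of documents containing it, then rank documents by how many query tokens they match. The function names and the three sample video titles are made up for illustration; a production system would add BM25-style weighting and merge these candidates with semantic (embedding) retrieval.

```python
from collections import defaultdict
from typing import Dict, List, Set

def build_inverted_index(docs: Dict[str, str]) -> Dict[str, Set[str]]:
    """Map each lowercase token to the ids of documents containing it."""
    index: Dict[str, Set[str]] = defaultdict(set)
    for doc_id, text in docs.items():
        for token in text.lower().split():
            index[token].add(doc_id)
    return index

def lexical_retrieve(index: Dict[str, Set[str]], query: str) -> List[str]:
    """Return doc ids ordered by matched query tokens (ties broken by id)."""
    hits: Dict[str, int] = defaultdict(int)
    for token in set(query.lower().split()):
        for doc_id in index.get(token, ()):
            hits[doc_id] += 1
    return sorted(hits, key=lambda d: (-hits[d], d))

docs = {
    "v1": "python tutorial for beginners",
    "v2": "advanced python tricks",
    "v3": "cooking pasta tutorial",
}
index = build_inverted_index(docs)
candidates = lexical_retrieve(index, "python tutorial")
print(candidates)
```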

**Q3: Design fraud detection for Stripe.**
- Real-time features (velocity, device fingerprint)
- Graph features (fraud rings, connected accounts)
- Focal loss for imbalanced data
- Human-in-the-loop for edge cases
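
The graph features in Q3 can start from something as simple as connected components: accounts that share a device (or card, or IP) fall into the same group, and group size becomes a feature. The sketch below uses union-find over user-device pairs; `connected_account_groups` and the sample data are hypothetical, and real fraud graphs would include more edge types.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

def connected_account_groups(user_device_pairs: List[Tuple[str, str]]) -> List[Set[str]]:
    """Group users linked through shared devices, largest group first."""
    parent: Dict[str, str] = {}

    def find(x: str) -> str:
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a: str, b: str) -> None:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Prefixes keep user ids and device ids in separate namespaces
    for user, device in user_device_pairs:
        union("u:" + user, "d:" + device)

    groups: Dict[str, Set[str]] = defaultdict(set)
    for user, _device in user_device_pairs:
        groups[find("u:" + user)].add(user)
    return sorted(groups.values(), key=lambda g: (-len(g), sorted(g)))

pairs = [("alice", "d1"), ("bob", "d1"), ("bob", "d2"),
         ("carol", "d2"), ("dave", "d3")]
groups = connected_account_groups(pairs)
print(groups)
```

Alice and Carol never share a device directly, but both are linked through Bob, which is the kind of indirect connection a per-transaction feature cannot see.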

**Q4: What metrics would you use for a recommendation system?**
- Offline: Precision@k, Recall@k, NDCG, Hit Rate
- Online: CTR, watch time, engagement rate
- Business: Revenue, retention, DAU
- Guardrails: Diversity, freshness, safety
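
The offline metrics in Q4 are short to compute once a ranked list and a set of relevant items are available. A minimal sketch for a single user, with made-up item ids; `precision_recall_at_k` is an illustrative name, and in practice the values are averaged over users.

```python
from typing import List, Set, Tuple

def precision_recall_at_k(recommended: List[str], relevant: Set[str],
                          k: int) -> Tuple[float, float, float]:
    """Return (precision@k, recall@k, hit rate@k) for one user."""
    top_k = recommended[:k]
    hits = sum(1 for item in top_k if item in relevant)
    precision = hits / k if k else 0.0
    recall = hits / len(relevant) if relevant else 0.0
    hit_rate = 1.0 if hits else 0.0  # did at least one relevant item appear?
    return precision, recall, hit_rate

recommended = ["a", "b", "c", "d", "e"]   # model's ranked output
relevant = {"b", "e", "z"}                # items the user actually engaged with
p, r, h = precision_recall_at_k(recommended, relevant, k=3)
print(f"P@3={p:.2f}, R@3={r:.2f}, Hit@3={h:.0f}")
```

Only `b` appears in the top 3, so precision@3 and recall@3 are both 1/3, and the hit rate is 1.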