# 🧠 Classification Basics: Predictive Modeling with Spark MLlib

**Time to complete:** 45 minutes  
**Difficulty:** Intermediate  
**Prerequisites:** DataFrames, statistics basics, ML concepts

---

## 🎯 Learning Objectives

By the end of this notebook, you will master:
- ✅ **Classification fundamentals** - Binary vs multiclass
- ✅ **Logistic regression** - Probability-based classification
- ✅ **Decision trees** - Tree-based classification
- ✅ **Random forests** - Ensemble classification
- ✅ **Model evaluation** - Accuracy, precision, recall, F1-score
- ✅ **Feature engineering** - Preparing data for ML
- ✅ **Pipeline construction** - End-to-end ML workflows

**Spark MLlib makes distributed machine learning accessible!**

---

## 🔍 Understanding Classification

### What is Classification?

**Classification** is a supervised learning technique that predicts categorical labels for input data.

```
Input Features:     [age, income, education, credit_score]
Prediction Task:    Will this customer default on their loan?
Output:             "Yes" (will default) or "No" (won't default)
```

### Types of Classification

1. **Binary Classification**: Two classes (Yes/No, True/False)
2. **Multiclass Classification**: Three or more classes (cat/dog/bird)
3. **Multilabel Classification**: Multiple labels per instance

**This notebook works through a binary problem (churn: yes/no); the same classifiers and evaluators also support multiclass labels.**

### Real-World Classification Examples

- **Fraud Detection**: Transaction → Fraudulent/legitimate
- **Spam Filtering**: Email → Spam/not spam
- **Medical Diagnosis**: Symptoms → Disease classification
- **Credit Scoring**: Application → Approve/deny
- **Image Recognition**: Pixels → Object categories
- **Sentiment Analysis**: Text → Positive/negative/neutral

In [None]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder \
    .appName("MLlib_Classification_Basics") \
    .master("local[*]") \
    .getOrCreate()

print(f"✅ Spark ready - Version: {spark.version}")
print("MLlib classification libraries imported")

# Note: tree settings such as maxBins and maxDepth are parameters of each estimator
# (set in the model cells below), not Spark session configs, so no spark.conf calls are needed

## 📊 Preparing Classification Data

### Sample Dataset: Customer Churn Prediction

**We'll predict whether customers will churn (cancel their subscription) based on their usage patterns.**

In [None]:
# Create sample customer churn dataset
print("📊 CREATING SAMPLE DATASET")
print("=" * 50)

# Generate synthetic customer data
import random
random.seed(42)

customer_data = []
for i in range(10000):
    # Customer features
    age = random.randint(18, 80)
    monthly_charges = round(random.uniform(20, 200), 2)
    tenure_months = random.randint(1, 72)
    total_charges = round(monthly_charges * tenure_months * random.uniform(0.8, 1.2), 2)
    
    # Categorical features
    gender = random.choice(["Male", "Female"])
    contract_type = random.choice(["Month-to-month", "One year", "Two year"])
    internet_service = random.choice(["DSL", "Fiber optic", "No"])
    
    # Usage features (randomly generated)
    avg_monthly_usage = random.randint(0, 1000)
    support_calls = random.randint(0, 10)
    
    # Churn risk score (simplified rule; not a calibrated probability)
    churn_risk = (
        (monthly_charges > 100) * 0.3 +
        (tenure_months < 12) * 0.4 +
        (contract_type == "Month-to-month") * 0.3 +
        (support_calls > 3) * 0.2 +
        random.uniform(0, 0.5)
    )
    
    churn = 1 if churn_risk > 0.7 else 0
    
    customer_data.append({
        "customer_id": f"CUST_{i:04d}",
        "age": age,
        "gender": gender,
        "monthly_charges": monthly_charges,
        "tenure_months": tenure_months,
        "total_charges": total_charges,
        "contract_type": contract_type,
        "internet_service": internet_service,
        "avg_monthly_usage": avg_monthly_usage,
        "support_calls": support_calls,
        "churn": churn
    })

# Create DataFrame
churn_df = spark.createDataFrame(customer_data)

print("Customer churn dataset created:")
print(f"Total customers: {churn_df.count():,}")
print(f"Churn rate: {churn_df.filter('churn = 1').count() / churn_df.count():.1%}")

print("\nDataset schema:")
churn_df.printSchema()

print("\nSample data:")
churn_df.show(5)

### Data Exploration

**Understanding your data is crucial before building ML models.**
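Because `churn` is a 0/1 label, a group's churn rate is simply the mean of the label, which is what `F.sum("churn") / F.count("*")` computes in the next cell (`F.avg("churn")` would give the same value). A minimal pure-Python sketch of that group-by logic on a few hypothetical rows; `churn_rate_by_group` is an illustrative helper, not part of the notebook:

```python
from collections import defaultdict

# Hypothetical (contract_type, churn) rows for illustration only
rows = [
    ("Month-to-month", 1), ("Month-to-month", 1), ("Month-to-month", 0),
    ("One year", 0), ("One year", 1),
    ("Two year", 0), ("Two year", 0),
]

def churn_rate_by_group(rows):
    """Return {group: (total, churned, churn_rate_pct)}, mirroring the groupBy/agg cell."""
    totals = defaultdict(int)
    churned = defaultdict(int)
    for group, label in rows:
        totals[group] += 1
        churned[group] += label  # summing a 0/1 label counts the churners
    return {g: (totals[g], churned[g], churned[g] / totals[g] * 100) for g in totals}

for group, (total, n_churned, rate) in churn_rate_by_group(rows).items():
    print(f"{group:<15} total={total} churned={n_churned} rate={rate:.1f}%")
```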

In [None]:
# Explore the dataset
print("🔍 DATA EXPLORATION")
print("=" * 50)

# Basic statistics
print("Target variable distribution:")
churn_df.groupBy("churn").count().show()

# Numeric feature statistics
numeric_cols = ["age", "monthly_charges", "tenure_months", "total_charges", "avg_monthly_usage", "support_calls"]
print("\nNumeric feature statistics:")
churn_df.select(numeric_cols).summary().show()

# Categorical feature distributions
categorical_cols = ["gender", "contract_type", "internet_service"]
print("\nCategorical feature distributions:")
for col in categorical_cols:
    print(f"\n{col}:")
    churn_df.groupBy(col).count().orderBy("count", ascending=False).show()

# Churn analysis by categories
print("\nChurn rate by contract type:")
churn_df.groupBy("contract_type").agg(
    F.count("*").alias("total"),
    F.sum("churn").alias("churned"),
    (F.sum("churn") / F.count("*") * 100).alias("churn_rate_%")
).show()

print("\nChurn rate by internet service:")
churn_df.groupBy("internet_service").agg(
    F.count("*").alias("total"),
    F.sum("churn").alias("churned"),
    (F.sum("churn") / F.count("*") * 100).alias("churn_rate_%")
).show()

## 🔧 Feature Engineering for Classification

### Preparing Features for ML Models

**Spark MLlib estimators expect all features in a single numeric vector column. We index and one-hot encode the categorical variables, then assemble everything into one `features` vector.**
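Conceptually, `StringIndexer` assigns each category a numeric index (most frequent category first by default) and `OneHotEncoder` turns that index into a 0/1 vector, dropping the last category by default (`dropLast=True`) so the all-zeros vector stands for it. The pure-Python sketch below mimics those defaults with hypothetical helpers (`string_index` and `one_hot` are not Spark APIs). It also shows why each three-category column contributes only two slots to the `features` vector.

```python
from collections import Counter

def string_index(values):
    """Map each category to a float index, most frequent first.

    Ties are broken alphabetically in this sketch.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {label: float(i) for i, label in enumerate(ordered)}

def one_hot(index, num_categories, drop_last=True):
    """Dense one-hot list; with drop_last, the last category maps to all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if int(index) < size:
        vec[int(index)] = 1.0
    return vec

# Hypothetical contract_type values (three categories with different frequencies)
contracts = ["Month-to-month"] * 3 + ["One year"] * 2 + ["Two year"]
mapping = string_index(contracts)
print(mapping)
for label in mapping:
    print(f"{label:<15} index={mapping[label]} one_hot={one_hot(mapping[label], len(mapping))}")
```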

In [None]:
# Feature engineering pipeline
print("🔧 FEATURE ENGINEERING")
print("=" * 50)

# 1. String indexing for categorical variables
gender_indexer = StringIndexer(inputCol="gender", outputCol="gender_index")
contract_indexer = StringIndexer(inputCol="contract_type", outputCol="contract_index")
internet_indexer = StringIndexer(inputCol="internet_service", outputCol="internet_index")

# 2. One-hot encoding for categorical variables
gender_encoder = OneHotEncoder(inputCol="gender_index", outputCol="gender_vec")
contract_encoder = OneHotEncoder(inputCol="contract_index", outputCol="contract_vec")
internet_encoder = OneHotEncoder(inputCol="internet_index", outputCol="internet_vec")

# 3. Assemble all features into a single vector
feature_cols = [
    "age", "monthly_charges", "tenure_months", "total_charges",
    "avg_monthly_usage", "support_calls",
    "gender_vec", "contract_vec", "internet_vec"
]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Create preprocessing pipeline
preprocessing_pipeline = Pipeline(stages=[
    gender_indexer, contract_indexer, internet_indexer,
    gender_encoder, contract_encoder, internet_encoder,
    assembler
])

# Fit and transform the data
preprocessing_model = preprocessing_pipeline.fit(churn_df)
processed_df = preprocessing_model.transform(churn_df)

print("Feature engineering completed")
print("\nProcessed data sample:")
processed_df.select("customer_id", "features", "churn").show(5, truncate=False)

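# Check feature vector size: one-hot encoding expands each categorical column
# (with the default dropLast=True, a k-category column adds k-1 slots), so the
# assembled vector is longer than the list of input columns
feature_vector_size = len(processed_df.select("features").first()["features"])
print(f"\nFeature vector size: {feature_vector_size} slots from {len(feature_cols)} input columns")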

### Train/Test Split

**Split data into training and testing sets for model evaluation.**
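`randomSplit` does not cut the data at exactly 80/20: each row is assigned independently, so split sizes are only approximately proportional to the weights (and `seed` makes the assignment repeatable). A simplified pure-Python sketch of the idea, using a hypothetical `random_split` helper. As we understand it, Spark samples per partition but uses the same principle of mapping each row's random draw onto cumulative weight ranges.

```python
import random

def random_split(rows, weights, seed=42):
    """Assign each row to one split by comparing a uniform draw to cumulative weights."""
    total = sum(weights)
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, upper in enumerate(bounds):
            # The last split also catches any floating-point shortfall in the final bound
            if u < upper or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits

train, test = random_split(list(range(1000)), [0.8, 0.2])
print(len(train), len(test))  # close to 800/200, but usually not exact
```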

In [None]:
# Train/test split
print("🎯 TRAIN/TEST SPLIT")
print("=" * 50)

# Split the processed data
train_df, test_df = processed_df.randomSplit([0.8, 0.2], seed=42)

print(f"Training set: {train_df.count():,} samples")
print(f"Testing set: {test_df.count():,} samples")
print(f"Split ratio: {train_df.count()/processed_df.count():.1%} train, {test_df.count()/processed_df.count():.1%} test")

# Check class distribution in splits
print("\nClass distribution:")
print("Training set:")
train_df.groupBy("churn").count().show()
print("Testing set:")
test_df.groupBy("churn").count().show()

# Check for data quality
print("\nData quality checks:")
print(f"Training null features: {train_df.filter('features is null').count()}")
print(f"Testing null features: {test_df.filter('features is null').count()}")
print(f"Training null labels: {train_df.filter('churn is null').count()}")
print(f"Testing null labels: {test_df.filter('churn is null').count()}")

## 📈 Logistic Regression: Probability-Based Classification

### Understanding Logistic Regression

**Logistic regression** predicts the probability of a binary outcome using a logistic (sigmoid) function.

```
Linear Combination: z = w₁x₁ + w₂x₂ + ... + wₙxₙ + b
Probability: p = 1 / (1 + e^(-z))
Prediction: class = 1 if p > 0.5, else 0
```

**Advantages:**
- Interpretable coefficients
- Probabilistic outputs
- Fast training and prediction
- Works well when the log-odds of the outcome are roughly linear in the features
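The formula above is small enough to write out directly. Below is a pure-Python sketch with hypothetical weights; a fitted Spark model learns these values, and `predict_proba` and `predict_class` are illustrative names, not MLlib methods. In MLlib, the 0.5 cutoff corresponds to the `threshold` parameter of `LogisticRegression`.

```python
import math

def predict_proba(weights, bias, x):
    """p = 1 / (1 + e^(-z)) with z = w·x + b."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return 1.0 / (1.0 + math.exp(-z))

def predict_class(weights, bias, x, threshold=0.5):
    """Class 1 when the probability exceeds the threshold, else class 0."""
    return 1 if predict_proba(weights, bias, x) > threshold else 0

# Hypothetical weights for two features, e.g. (support_calls, tenure_months)
w, b = [0.8, -0.1], -1.0
for x in ([5, 2], [1, 40]):
    print(x, round(predict_proba(w, b, x), 3), predict_class(w, b, x))
```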

In [None]:
# Logistic Regression model
print("📈 LOGISTIC REGRESSION")
print("=" * 50)

# Create and train logistic regression model
lr = LogisticRegression(
    featuresCol="features",
    labelCol="churn",
    maxIter=100,
    regParam=0.01,  # L2 regularization
    elasticNetParam=0.0,  # L2 only
    family="binomial"  # Binary classification
)

# Train the model
print("Training Logistic Regression model...")
lr_model = lr.fit(train_df)

print("✅ Model trained successfully")

# Make predictions
lr_predictions = lr_model.transform(test_df)

print("\nPredictions sample:")
lr_predictions.select("customer_id", "churn", "prediction", "probability").show(10)

# Model coefficients and intercept
print(f"\nNumber of coefficients: {lr_model.coefficients.size}")
print(f"Model intercept: {lr_model.intercept:.4f}")

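# Feature importance (absolute coefficients). The coefficient vector has one entry per
# slot of the assembled vector, so take slot names from the metadata VectorAssembler
# attaches rather than from feature_cols (one-hot columns expand to several slots).
# Coefficients are on the original feature scale, so compare magnitudes with care.
attr_groups = processed_df.schema["features"].metadata["ml_attr"]["attrs"].values()
feature_names = [a["name"] for a in sorted((a for g in attr_groups for a in g), key=lambda a: a["idx"])]
feature_importance = list(zip(feature_names, lr_model.coefficients.toArray()))
feature_importance.sort(key=lambda x: abs(x[1]), reverse=True)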

print("\nTop 5 most important features:")
for feature, coef in feature_importance[:5]:
    print(f"  {feature}: {coef:.4f}")

## 🌳 Decision Trees: Rule-Based Classification

### Understanding Decision Trees

**Decision trees** recursively split data based on feature values to create classification rules.

```
Root Node
└── Feature A > threshold?
    ├── Yes → Class 1
    └── No → Feature B > threshold?
        ├── Yes → Class 0
        └── No → Class 1
```

**Advantages:**
- Easy to interpret and visualize
- Handles both numerical and categorical features
- No need for feature scaling
- Can capture non-linear relationships
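To see how a tree chooses a split, here is a pure-Python sketch of Gini impurity (the `impurity="gini"` setting used below) and an exhaustive threshold search on one hypothetical feature. Spark does not test every distinct value; it limits candidate thresholds for continuous features using `maxBins`. Treat this as the idea rather than the implementation; `gini` and `best_threshold` are illustrative helpers.

```python
def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 over the classes present in labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def best_threshold(values, labels):
    """Try each candidate threshold; return (threshold, weighted child impurity)."""
    n = len(labels)
    best = (None, float("inf"))
    for t in sorted(set(values)):
        left = [y for v, y in zip(values, labels) if v <= t]
        right = [y for v, y in zip(values, labels) if v > t]
        if not left or not right:
            continue  # skip splits that leave one side empty
        impurity = len(left) / n * gini(left) + len(right) / n * gini(right)
        if impurity < best[1]:
            best = (t, impurity)
    return best

# Hypothetical tenure_months values with churn labels
tenure = [2, 5, 8, 20, 30, 45]
churn = [1, 1, 1, 0, 0, 1]
print("parent impurity:", gini(churn))
print("best split (tenure <= t):", best_threshold(tenure, churn))
```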

In [None]:
# Decision Tree model
print("🌳 DECISION TREE CLASSIFIER")
print("=" * 50)

# Create and train decision tree model
dt = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="churn",
    maxDepth=5,  # Limit tree depth to prevent overfitting
    maxBins=32,  # Number of bins for continuous features
    impurity="gini",  # Split quality measure
    seed=42
)

# Train the model
print("Training Decision Tree model...")
dt_model = dt.fit(train_df)

print("✅ Decision Tree trained successfully")

# Make predictions
dt_predictions = dt_model.transform(test_df)

print("\nPredictions sample:")
dt_predictions.select("customer_id", "churn", "prediction").show(10)

# Tree information
print(f"\nTree depth: {dt_model.depth}")
print(f"Number of nodes: {dt_model.numNodes}")
print(f"Number of feature importances: {dt_model.featureImportances.size}")

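# Feature importances, labeled with the assembled vector's slot names
# (one-hot columns expand to several slots, so feature_cols would mislabel them)
attr_groups = processed_df.schema["features"].metadata["ml_attr"]["attrs"].values()
feature_names = [a["name"] for a in sorted((a for g in attr_groups for a in g), key=lambda a: a["idx"])]
feature_importance_dt = list(zip(feature_names, dt_model.featureImportances.toArray()))
feature_importance_dt.sort(key=lambda x: x[1], reverse=True)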

print("\nTop 5 most important features:")
for feature, importance in feature_importance_dt[:5]:
    print(f"  {feature}: {importance:.4f}")

# Print tree structure (simplified)
print("\nDecision tree structure preview:")
print(dt_model.toDebugString[:500] + "...")

## 🌲 Random Forest: Ensemble Classification

### Understanding Random Forests

**Random forests** combine multiple decision trees to improve accuracy and reduce overfitting.

```
Training Process:
1. Create multiple decision trees
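2. Each tree is trained on a bootstrap sample (random sample with replacement) of the data
3. Each split considers only a random subset of features
4. Final prediction = combined vote of all trees (spark.ml averages the trees' class probabilities)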
```

**Advantages:**
- Usually more accurate than a single decision tree
- Reduced overfitting
- Feature importance estimation
- Robust to outliers (missing values must still be imputed before training in Spark MLlib)
- Parallel training
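The training steps above can be sketched in pure Python: a bootstrap sample per tree, then aggregation of the trees' outputs. Both aggregation styles are shown, since the textbook description is a hard majority vote while many implementations (spark.ml's `RandomForestClassifier` among them, as far as we know) average per-tree class probabilities. Function names are hypothetical, and per-split feature sampling (`featureSubsetStrategy` in Spark) is omitted.

```python
import random

def bootstrap_sample(rows, rng):
    """Draw len(rows) items with replacement: some rows repeat, others are left out."""
    return [rng.choice(rows) for _ in rows]

def majority_vote(predictions):
    """Hard voting: the most common class wins (ties go to the smaller label here)."""
    counts = {}
    for p in predictions:
        counts[p] = counts.get(p, 0) + 1
    return max(sorted(counts), key=lambda c: counts[c])

def average_probability(probabilities):
    """Soft voting: average each tree's estimated P(churn = 1)."""
    return sum(probabilities) / len(probabilities)

rng = random.Random(42)
rows = list(range(10))
sample = bootstrap_sample(rows, rng)
print("bootstrap sample:", sorted(sample))
print("rows left out of this tree's sample:", sorted(set(rows) - set(sample)))

# Hypothetical outputs from five trees for one customer
tree_votes = [1, 0, 1, 1, 0]
tree_probs = [0.9, 0.4, 0.7, 0.6, 0.3]
print("hard vote:", majority_vote(tree_votes))
print("soft vote P(churn):", average_probability(tree_probs))
```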

In [None]:
# Random Forest model
print("🌲 RANDOM FOREST CLASSIFIER")
print("=" * 50)

# Create and train random forest model
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="churn",
    numTrees=50,  # Number of trees in the forest
    maxDepth=6,   # Maximum depth of each tree
    maxBins=32,  # Number of bins for continuous features
    minInstancesPerNode=5,
    seed=42
)

# Train the model
print("Training Random Forest model...")
rf_model = rf.fit(train_df)

print("✅ Random Forest trained successfully")

# Make predictions
rf_predictions = rf_model.transform(test_df)

print("\nPredictions sample:")
rf_predictions.select("customer_id", "churn", "prediction").show(10)

# Forest information
print(f"\nNumber of trees: {rf_model.getNumTrees}")
print(f"Total number of nodes: {rf_model.totalNumNodes}")

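# Feature importances, labeled with the assembled vector's slot names
# (one-hot columns expand to several slots, so feature_cols would mislabel them)
attr_groups = processed_df.schema["features"].metadata["ml_attr"]["attrs"].values()
feature_names = [a["name"] for a in sorted((a for g in attr_groups for a in g), key=lambda a: a["idx"])]
feature_importance_rf = list(zip(feature_names, rf_model.featureImportances.toArray()))
feature_importance_rf.sort(key=lambda x: x[1], reverse=True)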

print("\nTop 5 most important features:")
for feature, importance in feature_importance_rf[:5]:
    print(f"  {feature}: {importance:.4f}")

# Individual tree information
print("\nForest composition:")
tree_info = rf_model.treeWeights  # Weights of each tree
print(f"All trees have equal weight: {len(set(tree_info)) == 1}")
print(f"Average tree weight: {sum(tree_info)/len(tree_info):.4f}")

## 📊 Model Evaluation and Comparison

### Classification Metrics

**Evaluating classification models requires multiple metrics beyond accuracy.**
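All of these metrics come from the same confusion-matrix counts. Below is a pure-Python sketch on hypothetical labels and scores; the helper names are illustrative, not Spark APIs. `weighted_f1` is included because `MulticlassClassificationEvaluator` with `metricName="f1"` reports F1 averaged over both classes, weighted by how often each class occurs, rather than the F1 of the churn class alone. `roc_auc` uses the ranking interpretation of ROC AUC (the probability that a random positive scores above a random negative). It should closely match the `areaUnderROC` that `BinaryClassificationEvaluator` computes from the curve.

```python
def binary_metrics(labels, predictions):
    """Accuracy plus precision, recall, and F1 for the positive class (label 1)."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0  # guard: no positive predictions
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

def weighted_f1(labels, predictions):
    """Per-class F1 weighted by each class's share of the true labels."""
    total = len(labels)
    result = 0.0
    for c in set(labels):
        # Relabel so class c is the positive class, then reuse binary_metrics
        y = [1 if v == c else 0 for v in labels]
        p = [1 if v == c else 0 for v in predictions]
        result += binary_metrics(y, p)["f1"] * y.count(1) / total
    return result

def roc_auc(labels, scores):
    """Chance a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for sp in pos:
        for sn in neg:
            wins += 1.0 if sp > sn else 0.5 if sp == sn else 0.0
    return wins / (len(pos) * len(neg))

# Hypothetical labels, hard predictions, and churn scores
labels = [1, 1, 1, 0, 0, 0, 0, 0]
predictions = [1, 1, 0, 1, 0, 0, 0, 0]
scores = [0.9, 0.8, 0.4, 0.6, 0.3, 0.2, 0.1, 0.05]
print(binary_metrics(labels, predictions))
print("weighted F1:", weighted_f1(labels, predictions))
print("AUC:", roc_auc(labels, scores))
```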

In [None]:
# Model evaluation
print("📊 MODEL EVALUATION")
print("=" * 50)

# Create evaluators
binary_evaluator = BinaryClassificationEvaluator(
    labelCol="churn",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

multiclass_evaluator = MulticlassClassificationEvaluator(
    labelCol="churn",
    predictionCol="prediction"
)

# Evaluate all models
models = {
    "Logistic Regression": lr_predictions,
    "Decision Tree": dt_predictions,
    "Random Forest": rf_predictions
}

results = {}

for model_name, predictions in models.items():
    print(f"\n🔍 Evaluating {model_name}:")
    
    # Confusion matrix
    confusion_matrix = predictions.groupBy("churn", "prediction").count()
    print("Confusion Matrix:")
    confusion_matrix.show()
    
    # Calculate metrics manually for clarity
    metrics = predictions.select("churn", "prediction").groupBy().agg(
        # Accuracy
        (F.sum(F.when(F.col("churn") == F.col("prediction"), 1).otherwise(0)) / F.count("*")).alias("accuracy"),
        
        # Precision for class 1 (churn)
        (F.sum(F.when((F.col("prediction") == 1) & (F.col("churn") == 1), 1).otherwise(0)) /
         F.sum(F.when(F.col("prediction") == 1, 1).otherwise(0))).alias("precision"),
        
        # Recall for class 1 (churn)
        (F.sum(F.when((F.col("prediction") == 1) & (F.col("churn") == 1), 1).otherwise(0)) /
         F.sum(F.when(F.col("churn") == 1, 1).otherwise(0))).alias("recall")
    ).collect()[0]
    
    # AUC using evaluator
    auc = binary_evaluator.evaluate(predictions)
    
    # Weighted F1 across both classes (not the class-1 F1 implied by the precision/recall above)
    f1 = multiclass_evaluator.evaluate(predictions, {multiclass_evaluator.metricName: "f1"})
    
    results[model_name] = {
        "accuracy": metrics.accuracy,
        "precision": metrics.precision,
        "recall": metrics.recall,
        "f1_score": f1,
        "auc": auc
    }
    
    print(f"  Accuracy: {metrics.accuracy:.4f}")
    print(f"  Precision: {metrics.precision:.4f}")
    print(f"  Recall: {metrics.recall:.4f}")
    print(f"  F1 Score: {f1:.4f}")
    print(f"  AUC: {auc:.4f}")

# Model comparison
print("\n🏆 MODEL COMPARISON:")
print("Metric" + "\t" + "\t".join(f"{model[:8]:<8}" for model in results.keys()))
print("-" * 60)

metrics_to_show = ["accuracy", "precision", "recall", "f1_score", "auc"]
for metric in metrics_to_show:
    values = [f"{results[model][metric]:.4f}" for model in results.keys()]
    print(f"{metric.capitalize()}\t" + "\t".join(values))

# Find best model for each metric
print("\n🏅 BEST MODELS:")
for metric in metrics_to_show:
    best_model = max(results.keys(), key=lambda m: results[m][metric])
    best_value = results[best_model][metric]
    print(f"  {metric.capitalize()}: {best_model} ({best_value:.4f})")

## 🔧 Building Complete ML Pipelines

### End-to-End Classification Pipeline

**Pipelines combine preprocessing, feature engineering, and model training into a single workflow.**
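Spark pipelines follow an Estimator/Transformer pattern: calling `fit()` on a pipeline fits each estimator stage on the output of the stages before it and returns a `PipelineModel` made only of transformers. A toy pure-Python sketch of that control flow; all class names here are hypothetical, and this is not Spark's implementation.

```python
class IndexerEstimator:
    """Hypothetical estimator: fit() learns a category -> index mapping from the data."""

    def __init__(self, in_col, out_col):
        self.in_col, self.out_col = in_col, out_col

    def fit(self, rows):
        labels = sorted({r[self.in_col] for r in rows})
        return IndexerModel(self.in_col, self.out_col, {v: i for i, v in enumerate(labels)})


class IndexerModel:
    """Hypothetical transformer returned by IndexerEstimator.fit()."""

    def __init__(self, in_col, out_col, mapping):
        self.in_col, self.out_col, self.mapping = in_col, out_col, mapping

    def transform(self, rows):
        return [{**r, self.out_col: self.mapping[r[self.in_col]]} for r in rows]


class AddOne:
    """Hypothetical stateless transformer: no fit() needed."""

    def __init__(self, in_col, out_col):
        self.in_col, self.out_col = in_col, out_col

    def transform(self, rows):
        return [{**r, self.out_col: r[self.in_col] + 1} for r in rows]


class SimplePipeline:
    """Fit stages in order; each estimator sees the output of the stages before it."""

    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted = []
        for stage in self.stages:
            if hasattr(stage, "fit"):
                stage = stage.fit(rows)  # estimator -> fitted transformer
            rows = stage.transform(rows)  # feed transformed rows to the next stage
            fitted.append(stage)
        return SimplePipelineModel(fitted)


class SimplePipelineModel:
    """Applies every fitted stage in order, like a PipelineModel."""

    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


train_rows = [{"plan": "basic"}, {"plan": "pro"}, {"plan": "basic"}]
pipeline = SimplePipeline([IndexerEstimator("plan", "plan_idx"), AddOne("plan_idx", "plan_idx_plus_one")])
model = pipeline.fit(train_rows)
print(model.transform([{"plan": "pro"}]))
```

Because the mapping is learned inside `fit()`, fitting the pipeline on training rows only keeps held-out rows from influencing the preprocessing.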

In [None]:
# Complete ML pipeline
print("🔧 COMPLETE ML PIPELINE")
print("=" * 50)

# Create a complete pipeline with preprocessing + model
complete_pipeline = Pipeline(stages=[
    # Preprocessing stages
    gender_indexer,
    contract_indexer, 
    internet_indexer,
    gender_encoder,
    contract_encoder,
    internet_encoder,
    assembler,
    
    # Model stage
    RandomForestClassifier(
        featuresCol="features",
        labelCol="churn",
        numTrees=30,
        maxDepth=5,
        seed=42
    )
])

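# Train the complete pipeline on a raw (unprocessed) training split; the
# pipeline's own preprocessing stages are fitted as part of complete_pipeline.fit()
raw_train_df, raw_test_df = churn_df.randomSplit([0.8, 0.2], seed=42)
print("Training complete pipeline...")
pipeline_model = complete_pipeline.fit(raw_train_df)

print("✅ Pipeline trained successfully")

# Make predictions on held-out rows so the metrics below are not inflated
pipeline_predictions = pipeline_model.transform(raw_test_df)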

print("\nPipeline predictions sample:")
pipeline_predictions.select("customer_id", "churn", "prediction").show(10)

# Pipeline stages
print(f"\nPipeline has {len(pipeline_model.stages)} stages:")
for i, stage in enumerate(pipeline_model.stages):
    stage_name = stage.__class__.__name__
    if hasattr(stage, "getNumTrees"):
        stage_name += f" ({stage.getNumTrees} trees)"
    print(f"  Stage {i+1}: {stage_name}")

# Evaluate pipeline performance
pipeline_auc = binary_evaluator.evaluate(pipeline_predictions)
pipeline_f1 = multiclass_evaluator.evaluate(pipeline_predictions, {multiclass_evaluator.metricName: "f1"})

print(f"\nPipeline Performance:")
print(f"  AUC: {pipeline_auc:.4f}")
print(f"  F1 Score: {pipeline_f1:.4f}")

# Save the pipeline
pipeline_path = "/tmp/churn_prediction_pipeline"
pipeline_model.write().overwrite().save(pipeline_path)
print(f"\nPipeline saved to: {pipeline_path}")

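# Load the saved pipeline back and score a few rows to confirm the round trip
from pyspark.ml import PipelineModel
loaded_pipeline = PipelineModel.load(pipeline_path)
loaded_pipeline.transform(churn_df.limit(5)).select("customer_id", "prediction").show()
print("Pipeline loaded from disk and ready for new predictions")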

## ⚙️ Hyperparameter Tuning

### Optimizing Model Parameters

**Hyperparameter tuning** finds the best model configuration through a systematic search over candidate parameter values, scoring each candidate with cross-validation.
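`ParamGridBuilder` builds the Cartesian product of the listed values, and `CrossValidator` fits one model per combination per fold. A pure-Python sketch of that bookkeeping; the helper names are hypothetical. This sketch deals shuffled rows into equal-size folds, whereas Spark assigns rows to folds randomly, so its fold sizes are only approximately equal.

```python
import itertools
import random

def build_grid(**param_values):
    """Cartesian product of candidate values, like addGrid(...).build()."""
    names = list(param_values)
    return [dict(zip(names, combo)) for combo in itertools.product(*param_values.values())]

def kfold_indices(n, k, seed=42):
    """Shuffle row indices and deal them into k folds of (nearly) equal size."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]

grid = build_grid(numTrees=[10, 20, 30], maxDepth=[3, 5, 7])
folds = kfold_indices(n=12, k=3)
print(len(grid), "combinations x", len(folds), "folds =", len(grid) * len(folds), "model fits")
for i, held_out in enumerate(folds):
    # Train on every fold except the held-out one
    train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
    print(f"fold {i}: train on {len(train_idx)} rows, validate on {len(held_out)} rows")
```

With the 3 × 3 grid and `numFolds=3` used below, that is 27 forest fits, plus one more: after picking the combination with the best average metric, `CrossValidator` refits it on the whole input to produce `bestModel`.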

In [None]:
# Hyperparameter tuning
print("⚙️ HYPERPARAMETER TUNING")
print("=" * 50)

# Create parameter grid for Random Forest
param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20, 30]) \
    .addGrid(rf.maxDepth, [3, 5, 7]) \
    .build()

print(f"Parameter grid size: {len(param_grid)} combinations")
print("\nParameter combinations:")
for i, params in enumerate(param_grid):
    num_trees = params[rf.numTrees]
    max_depth = params[rf.maxDepth]
    print(f"  {i+1}. numTrees={num_trees}, maxDepth={max_depth}")

# Create cross-validator
cross_validator = CrossValidator(
    estimator=rf,
    estimatorParamMaps=param_grid,
    evaluator=binary_evaluator,
    numFolds=3,  # 3-fold cross-validation
    seed=42
)

# Note: Full cross-validation can be slow, so we'll demonstrate with smaller data
# In production, use a representative sample for tuning
tuning_sample = train_df.sample(0.3, seed=42)  # 30% sample for faster tuning

print(f"\nTuning on sample: {tuning_sample.count()} records")

# Perform cross-validation (this may take a few minutes)
print("Performing cross-validation...")
cv_model = cross_validator.fit(tuning_sample)

print("✅ Cross-validation completed")

# Best model parameters
best_rf_model = cv_model.bestModel
best_params = cv_model.getEstimatorParamMaps()[cv_model.avgMetrics.index(max(cv_model.avgMetrics))]

print(f"\nBest parameters found:")
print(f"  numTrees: {best_params[rf.numTrees]}")
print(f"  maxDepth: {best_params[rf.maxDepth]}")
print(f"  Best mean cross-validated AUC: {max(cv_model.avgMetrics):.4f}")

# Evaluate best model on full test set
tuned_predictions = best_rf_model.transform(test_df)
tuned_auc = binary_evaluator.evaluate(tuned_predictions)
tuned_f1 = multiclass_evaluator.evaluate(tuned_predictions, {multiclass_evaluator.metricName: "f1"})

print(f"\nTuned model performance on full test set:")
print(f"  AUC: {tuned_auc:.4f}")
print(f"  F1 Score: {tuned_f1:.4f}")

# Compare with the hand-configured Random Forest trained earlier (50 trees, depth 6)
baseline_auc = binary_evaluator.evaluate(rf_predictions)
improvement = (tuned_auc - baseline_auc) / baseline_auc * 100

print(f"\nChange in AUC vs. earlier Random Forest: {improvement:+.1f}%")

## 🎯 Production Deployment Considerations

### Making Models Production-Ready

**Production ML systems require reliability, monitoring, and maintainability.**
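The monitoring guidelines in the cell below mention logging feature distributions for drift detection. One common drift statistic (not part of Spark MLlib or of this notebook's pipeline) is the population stability index, PSI = Σ (aᵢ − eᵢ) · ln(aᵢ / eᵢ), where eᵢ and aᵢ are the fractions of training-time and scoring-time values in bin i. A minimal sketch with hypothetical bin edges and values; the helper names are illustrative.

```python
import math

def bin_fractions(values, edges, eps=1e-6):
    """Fraction of values per [edges[i], edges[i+1]) bin; the last bin includes its right edge.

    Values outside the edges are ignored in this sketch.
    """
    counts = [0] * (len(edges) - 1)
    for v in values:
        for i in range(len(edges) - 1):
            in_last_bin = i == len(edges) - 2 and v == edges[-1]
            if edges[i] <= v < edges[i + 1] or in_last_bin:
                counts[i] += 1
                break
    total = max(sum(counts), 1)
    # A small floor keeps log() finite when a bin is empty
    return [max(c / total, eps) for c in counts]

def population_stability_index(expected, actual, edges):
    """PSI = sum over bins of (a - e) * ln(a / e)."""
    e = bin_fractions(expected, edges)
    a = bin_fractions(actual, edges)
    return sum((ai - ei) * math.log(ai / ei) for ei, ai in zip(e, a))

# Hypothetical monthly_charges values at training time and at scoring time
edges = [20, 65, 110, 155, 200]
training_charges = [25, 40, 70, 90, 120, 130, 160, 190]
scoring_charges = [60, 150, 160, 170, 180, 185, 190, 195]
print(population_stability_index(training_charges, training_charges, edges))
print(population_stability_index(training_charges, scoring_charges, edges))
```

A commonly quoted rule of thumb treats PSI below 0.1 as stable and above 0.25 as a significant shift; thresholds are worth tuning per feature.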

In [None]:
# Production considerations
print("🎯 PRODUCTION DEPLOYMENT")
print("=" * 50)

# 1. Model serialization and versioning
print("1. MODEL SERIALIZATION & VERSIONING")
model_version = "1.0.0"
model_path = f"/tmp/churn_model_v{model_version}"

# Save the best model
best_rf_model.write().overwrite().save(model_path)
print(f"Model saved to: {model_path}")

# Save preprocessing pipeline separately
preprocessing_path = f"/tmp/churn_preprocessing_v{model_version}"
preprocessing_model.write().overwrite().save(preprocessing_path)
print(f"Preprocessing pipeline saved to: {preprocessing_path}")

# 2. Model metadata
print("\n2. MODEL METADATA")
import datetime
model_metadata = {
    "model_type": "RandomForestClassifier",
    "version": model_version,
    "training_date": datetime.date.today().isoformat(),
    "input_columns": feature_cols,  # assembler inputs, before one-hot expansion
    "target": "churn",
    "metrics": {
        "auc": tuned_auc,
        "f1_score": tuned_f1,
        "accuracy": multiclass_evaluator.evaluate(tuned_predictions, {multiclass_evaluator.metricName: "accuracy"})
    },
    "hyperparameters": {
        "numTrees": best_params[rf.numTrees],
        "maxDepth": best_params[rf.maxDepth],
        "maxBins": 32
    },
    "data_info": {
        "training_samples": train_df.count(),
        "test_samples": test_df.count(),
        "input_column_count": len(feature_cols)
    }
}

# Save metadata (in production, use a proper metadata store)
import json
metadata_path = f"/tmp/churn_model_metadata_v{model_version}.json"
with open(metadata_path, 'w') as f:
    json.dump(model_metadata, f, indent=2, default=str)

print(f"Model metadata saved to: {metadata_path}")

# 3. Prediction function for production use
print("\n3. PRODUCTION PREDICTION FUNCTION")

def predict_churn(customer_data):
    """
    Production prediction function for customer churn
    
    Args:
        customer_data: DataFrame with customer features
    
    Returns:
        DataFrame with churn predictions
    """
    try:
        # Load models (in production, load once at startup)
        from pyspark.ml.classification import RandomForestClassificationModel
        model = RandomForestClassificationModel.load(model_path)
        
        # Load preprocessing pipeline
        from pyspark.ml import PipelineModel
        preprocessing = PipelineModel.load(preprocessing_path)
        
        # Process and predict
        processed_data = preprocessing.transform(customer_data)
        predictions = model.transform(processed_data)
        
        # Return relevant columns
        result = predictions.select(
            "customer_id",
            "prediction",
            "probability"
        )
        
        return result
        
    except Exception as e:
        print(f"Prediction error: {e}")
        return None

# Test production function on a few existing customers (the models were saved above)
test_customers = churn_df.limit(5)
print("\nTesting production prediction function:")
predictions = predict_churn(test_customers)
if predictions is not None:
    predictions.show(truncate=False)

# 4. Monitoring and alerting
print("\n4. MONITORING & ALERTING")
monitoring_guidelines = [
    "Monitor prediction latency and throughput",
    "Track model performance drift over time",
    "Set up alerts for prediction failures",
    "Log feature distributions for drift detection",
    "Implement A/B testing for model updates",
    "Regular model retraining schedule"
]

print("Production monitoring guidelines:")
for guideline in monitoring_guidelines:
    print(f"  ✓ {guideline}")

## 🎯 Key Takeaways

### Classification Concepts Mastered:

1. **Binary Classification**: Predicting churn (yes/no)
2. **Feature Engineering**: Categorical encoding, feature assembly
3. **Model Algorithms**: Logistic regression, decision trees, random forests
4. **Model Evaluation**: Accuracy, precision, recall, F1, AUC
5. **ML Pipelines**: End-to-end preprocessing + modeling
6. **Hyperparameter Tuning**: Cross-validation, parameter search
7. **Production Deployment**: Model serialization, monitoring, versioning

### Model Performance Comparison:

| Algorithm | Strengths | Weaknesses | Use Case |
|-----------|-----------|------------|----------|
| **Logistic Regression** | Interpretable, fast, probabilistic | Assumes linear relationships | Baseline, interpretable models |
| **Decision Trees** | Interpretable, handles non-linear | Prone to overfitting | Small datasets, interpretability |
| **Random Forest** | High accuracy, robust, feature importance | Less interpretable, slower | Production ML, high accuracy |

### Production ML Checklist:

- ✅ **Data Quality**: Validation, missing value handling
- ✅ **Feature Engineering**: Scaling, encoding, selection
- ✅ **Model Selection**: Algorithm choice, hyperparameter tuning
- ✅ **Evaluation**: Multiple metrics, cross-validation
- ✅ **Deployment**: Serialization, versioning, monitoring
- ✅ **Maintenance**: Drift detection, retraining schedule

### Business Impact:

**Well-implemented classification can:**
- **Reduce churn**: Target at-risk customers for retention
- **Increase revenue**: Personalize marketing and offers
- **Optimize operations**: Predict demand and resource needs
- **Prevent fraud**: Identify suspicious transactions
- **Improve customer experience**: Proactive service delivery

---

## 🚀 Next Steps

Now that you understand classification basics, you're ready for:

1. **Regression Techniques** - Predicting continuous values
2. **Clustering Algorithms** - Unsupervised learning patterns
3. **Recommendation Systems** - Collaborative filtering
4. **Feature Engineering** - Advanced feature creation
5. **Model Evaluation** - Deep dive into metrics and validation
6. **ML Pipeline Design** - Complex workflow orchestration

**Classification is the foundation of supervised machine learning!**

---

**🎉 Congratulations! You now have a solid grounding in classification with Spark MLlib!**