# Advanced Machine Learning Models for Imbalanced Classification in PySpark

This notebook implements multiple classification models using PySpark with techniques specifically designed for handling class imbalance.

## Table of Contents
1. [Configuration and Setup](#1.-Configuration-and-Setup)
2. [Importing Libraries](#2.-Importing-Libraries)
3. [Loading and Exploring Data](#3.-Loading-and-Exploring-Data)
4. [Class Imbalance Handling](#4.-Class-Imbalance-Handling)
5. [Helper Functions](#5.-Helper-Functions)
    - [5.1 Random Search Function](#5.1-Random-Search-Function)
    - [5.2 Evaluation Functions](#5.2-Evaluation-Functions)
6. [Model Training](#6.-Model-Training)
    - [6.1 Logistic Regression](#6.1-Logistic-Regression)
    - [6.2 Random Forest](#6.2-Random-Forest)
    - [6.3 Gradient Boosted Trees](#6.3-Gradient-Boosted-Trees-(GBDT))
    - [6.4 Multilayer Perceptron](#6.4-Multilayer-Perceptron-(MLP))
7. [Model Comparison](#7.-Model-Comparison)
8. [Threshold Optimization](#8.-Threshold-Optimization)
9. [Test Set Evaluation](#9.-Test-Set-Evaluation-with-Best-Model)
10. [Conclusion](#10.-Conclusion)

## 1. Configuration and Setup

In [None]:
# Model selection configuration - Set to True to run, False to skip
RUN_MODELS = {
    'logistic_regression': True,
    'random_forest': True,
    'gbdt': True,
    'mlp': True
}

# Class imbalance handling configuration
IMBALANCE_CONFIG = {
    # Balance method: None, 'class_weight', 'undersample', 'oversample', or 'both'
    'method': 'both',
    
    # Undersampling configuration (if method = 'undersample' or 'both')
    'undersample_ratio': 0.4,  # Higher = keep more majority class samples
    
    # Oversampling configuration (if method = 'oversample' or 'both')
    'oversample_ratio': 0.7,  # Target ratio of minority to majority class
    
    # Apply threshold adjustment after model training
    'adjust_threshold': True,
    
    # Enable class weights (for models that support it)
    'use_class_weights': True
}

# Random search configuration
RANDOM_SEARCH_CONFIG = {
    'num_folds': 3,      # Number of folds for cross-validation
    'parallelism': 1     # Reduced to 1 to avoid resource issues
}

# File paths configuration
DATA_PATHS = {
    'train': "dbfs:/FileStore/tables/train_df-2.csv",
    'val': "dbfs:/FileStore/tables/val_df.csv",
    'test': "dbfs:/FileStore/tables/test_df-2.csv"
}

## 2. Importing Libraries

In [None]:
# Importing libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from math import ceil
import time

# PySpark imports
from pyspark.sql import SparkSession  
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number, when, lit, count, lag, expr, udf, rand
from pyspark.sql.types import DoubleType, ArrayType, FloatType

# ML imports for classification
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, MultilayerPerceptronClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.mllib.evaluation import MulticlassMetrics

# Scikit-learn imports for metrics
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve

# Initialize Spark session
spark = SparkSession.builder.appName("ML Models Spark").getOrCreate()

## 3. Loading and Exploring Data

In [None]:
def load_and_preprocess_data(file_paths):
    """Load and preprocess data from specified file paths.
    
    Args:
        file_paths (dict): Dictionary with keys 'train', 'val', 'test' and file path values
        
    Returns:
        tuple: Preprocessed train, validation, and test data
    """
    # Load data
    train_data = spark.read.csv(file_paths['train'], header=True, inferSchema=True)
    val_data = spark.read.csv(file_paths['val'], header=True, inferSchema=True)
    test_data = spark.read.csv(file_paths['test'], header=True, inferSchema=True)
    
    # Select feature columns (all except 'label', 'time', and 'file')
    feature_cols = [col for col in train_data.columns if col not in ['label', 'time', 'file']]
    
    # Assemble features into a single vector column
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    train_data = assembler.transform(train_data).select("features", "label")
    val_data = assembler.transform(val_data).select("features", "label")
    test_data = assembler.transform(test_data).select("features", "label")
    
    print(f"Data loaded and preprocessed:")
    print(f"  - Training samples: {train_data.count()}")
    print(f"  - Validation samples: {val_data.count()}")
    print(f"  - Test samples: {test_data.count()}")
    
    return train_data, val_data, test_data

def analyze_class_distribution(dataframe, title="Class Distribution"):
    """Analyze and visualize class distribution in the dataset.
    
    Args:
        dataframe: PySpark DataFrame with 'label' column
        title: Title for the analysis
    
    Returns:
        dict: Class distribution statistics
    """
    # Count samples by class
    class_counts = dataframe.groupBy("label").count().orderBy("label").collect()
    
    # Calculate percentages
    total_count = dataframe.count()
    class_stats = {}
    
    print(f"\n{title}:")
    print("Class\tCount\tPercentage")
    print("-" * 30)
    
    classes = []
    counts = []
    percentages = []
    
    for row in class_counts:
        class_label = row["label"]
        count = row["count"]
        percentage = (count / total_count) * 100
        
        classes.append(class_label)
        counts.append(count)
        percentages.append(percentage)
        
        class_stats[class_label] = {"count": count, "percentage": percentage}
        print(f"{class_label}\t{count}\t{percentage:.2f}%")
    
    # Calculate imbalance ratio (majority / minority)
    majority_count = max(counts)
    minority_count = min(counts)
    imbalance_ratio = majority_count / minority_count
    print(f"\nImbalance Ratio (majority/minority): {imbalance_ratio:.2f}")
    
    # Visualize class distribution
    plt.figure(figsize=(10, 6))
    plt.bar(classes, counts)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.title(title)
    
    # Add count labels on top of bars
    for i, count in enumerate(counts):
        plt.text(classes[i], count + (0.01 * majority_count), f"{count}\n({percentages[i]:.2f}%)", 
                 ha='center', va='bottom')
    
    plt.xticks(classes)
    plt.tight_layout()
    plt.show()
    
    return class_stats

# Load and preprocess the data
train_data, val_data, test_data = load_and_preprocess_data(DATA_PATHS)

# Display a few samples
print("\nSample of training data:")
train_data.show(3)

# Analyze class distribution
train_class_stats = analyze_class_distribution(train_data, "Training Data Class Distribution")
val_class_stats = analyze_class_distribution(val_data, "Validation Data Class Distribution")
test_class_stats = analyze_class_distribution(test_data, "Test Data Class Distribution")

# Count classes
num_classes = len(train_class_stats)
print(f"\nNumber of classes: {num_classes}")

## 4. Class Imbalance Handling

In [None]:
def handle_class_imbalance(train_data, class_stats, config):
    """Apply class imbalance handling techniques to the training data.
    
    Args:
        train_data: PySpark DataFrame with 'features' and 'label' columns
        class_stats: Dictionary with class distribution statistics
        config: Configuration dictionary for imbalance handling
        
    Returns:
        DataFrame: Processed training data
        dict: Class weights dictionary (if enabled)
    """
    method = config.get('method', None)
    processed_data = train_data
    
    # Calculate class weights based on inverse frequency
    total_samples = sum(stats["count"] for stats in class_stats.values())
    class_weights = {}
    
    for class_label, stats in class_stats.items():
        weight = total_samples / (num_classes * stats["count"])
        class_weights[class_label] = weight
    
    print("\nClass weights:")
    for class_label, weight in class_weights.items():
        print(f"Class {class_label}: {weight:.4f}")
    
    # Apply class weight column if enabled
    if config.get('use_class_weights', False):
        print("\nApplying class weights as a column in the dataset...")
        # Create a weight column based on class
        weight_expr = None
        for label, weight in class_weights.items():
            if weight_expr is None:
                weight_expr = when(col("label") == label, lit(weight))
            else:
                weight_expr = weight_expr.when(col("label") == label, lit(weight))
        
        processed_data = processed_data.withColumn("weight", weight_expr)
    
    # Apply resampling methods
    if method in ['undersample', 'both']:
        print("\nApplying undersampling to majority class...")
        
        # Find majority and minority classes
        class_counts = [(label, stats["count"]) for label, stats in class_stats.items()]
        class_counts.sort(key=lambda x: x[1], reverse=True)
        majority_class = class_counts[0][0]
        majority_count = class_counts[0][1]
        
        # Determine fraction to sample from majority class
        minority_counts = [count for label, count in class_counts[1:]]  
        avg_minority_count = sum(minority_counts) / len(minority_counts)
        undersampling_fraction = min(1.0, (avg_minority_count / majority_count) * config.get('undersample_ratio', 1.0))
        
        # Apply undersampling
        majority_samples = processed_data.filter(col("label") == majority_class).sample(False, undersampling_fraction)
        other_samples = processed_data.filter(col("label") != majority_class)
        processed_data = majority_samples.union(other_samples)
        
        print(f"Undersampled majority class (class {majority_class}) with fraction: {undersampling_fraction:.4f}")
        
    if method in ['oversample', 'both']:
        print("\nApplying oversampling to minority classes...")
        
        # Find counts after potential undersampling
        current_counts = processed_data.groupBy("label").count().collect()
        class_counts = {row["label"]: row["count"] for row in current_counts}
        
        # Sort classes by count
        sorted_classes = sorted(class_counts.items(), key=lambda x: x[1])
        minority_class = sorted_classes[0][0]
        minority_count = sorted_classes[0][1]
        majority_class = sorted_classes[-1][0]
        majority_count = sorted_classes[-1][1]
        
        # Temporary result to build up with oversampled data
        result = processed_data
        
        # Oversample all classes except the majority
        target_count = int(majority_count * config.get('oversample_ratio', 0.6))
        
        for class_label, count in sorted_classes[:-1]:  # Skip the majority class
            if count < target_count:
                # Calculate number of times to repeat (at least 1)
                multiplier = max(1, int(target_count / count))
                remainder_fraction = (target_count % count) / count
                
                # Get the minority samples
                minority_samples = processed_data.filter(col("label") == class_label)
                
                # Apply oversampling by repeating the data
                oversampled = minority_samples
                for _ in range(multiplier - 1):  # -1 because we already have one copy
                    oversampled = oversampled.union(minority_samples)
                
                # Add the remainder using sampling
                if remainder_fraction > 0:
                    remainder = minority_samples.sample(True, remainder_fraction)
                    oversampled = oversampled.union(remainder)
                
                # Remove original instances of this class
                result = result.filter(col("label") != class_label).union(oversampled)
                
                print(f"Oversampled class {class_label} from {count} to approximately {target_count} samples")
        
        processed_data = result
    
    # Check the resulting class distribution
    if method is not None:
        print("\nClass distribution after imbalance handling:")
        rebalanced_counts = processed_data.groupBy("label").count().orderBy("label").collect()
        
        for row in rebalanced_counts:
            print(f"Class {row['label']}: {row['count']} samples")
    
    return processed_data, class_weights

# Apply class imbalance handling to training data
if IMBALANCE_CONFIG['method'] is not None:
    print("\nApplying class imbalance handling techniques...")
    balanced_train_data, class_weights = handle_class_imbalance(train_data, train_class_stats, IMBALANCE_CONFIG)
    
    # Replace original training data with balanced version
    train_data = balanced_train_data
else:
    print("\nNo class imbalance handling techniques applied.")
    class_weights = None

## 5. Helper Functions

### 5.1 Random Search Function

In [None]:
def perform_random_search(model, param_grid, train_data, val_data, num_folds=3, parallelism=1, 
                          use_f1=True, class_weights=None):
    """Perform random search for hyperparameter tuning of a model.
    
    Args:
        model: Machine learning model instance
        param_grid: Parameter grid for random search
        train_data: Training data DataFrame
        val_data: Validation data DataFrame
        num_folds: Number of cross-validation folds
        parallelism: Number of parallel tasks
        use_f1: Whether to use F1 score (True) or AUC (False) for evaluation
        class_weights: Dictionary of class weights to use
        
    Returns:
        tuple: (best_model, best_params, train_predictions, val_predictions, train_f1, val_f1)
    """
    # Initialize the evaluator based on chosen metric
    if use_f1:
        evaluator = MulticlassClassificationEvaluator(
            labelCol='label', 
            predictionCol='prediction', 
            metricName='f1'
        )
    else:
        # For binary classification or when we care more about ranking
        evaluator = BinaryClassificationEvaluator(
            labelCol='label',
            rawPredictionCol='rawPrediction',
            metricName='areaUnderPR'  # Area under precision-recall curve is better for imbalanced data
        )
    
    # Initialize CrossValidator for hyperparameter tuning
    cv = CrossValidator(
        estimator=model,
        estimatorParamMaps=param_grid,
        evaluator=evaluator,
        numFolds=num_folds,
        parallelism=parallelism
    )
    
    # Start timing
    start_time = time.time()
    
    # Fit the cross-validator to the training data
    print("Training model with random search...")
    cv_model = cv.fit(train_data)
    
    # End timing
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds")
    
    # Extract the best model
    best_model = cv_model.bestModel
    
    # Get best parameters
    best_params = {}
    for param in best_model.extractParamMap():
        param_name = param.name
        param_value = best_model.getOrDefault(param)
        best_params[param_name] = param_value
    
    # Make predictions with the best model
    train_predictions = best_model.transform(train_data)
    val_predictions = best_model.transform(val_data)
    
    # Use F1 for final evaluation to be consistent
    f1_evaluator = MulticlassClassificationEvaluator(
        labelCol='label', 
        predictionCol='prediction', 
        metricName='f1'
    )
    
    # Calculate F1 scores
    train_f1 = f1_evaluator.evaluate(train_predictions)
    val_f1 = f1_evaluator.evaluate(val_predictions)
    
    return best_model, best_params, train_predictions, val_predictions, train_f1, val_f1

# Special function for MLP which needs custom random search due to layers parameter
def perform_mlp_random_search(train_data, val_data, num_features, num_classes):
    """Perform custom random search for MLP hyperparameter tuning.
    
    Args:
        train_data: Training data DataFrame
        val_data: Validation data DataFrame
        num_features: Number of input features
        num_classes: Number of classes
        
    Returns:
        tuple: (best_model, best_params, train_predictions, val_predictions, train_f1, val_f1)
    """
    # Initialize the evaluator for F1 score
    evaluator = MulticlassClassificationEvaluator(
        labelCol='label', 
        predictionCol='prediction', 
        metricName='f1'
    )
    
    # Define different network architectures for random search
    layers_options = [
        [num_features, num_features, num_classes],  # Simple network
        [num_features, num_features * 2, num_features, num_classes],  # Medium network
        [num_features, num_features * 2, num_features * 2, num_features, num_classes]  # Complex network
    ]
    
    # Define parameter combinations
    block_sizes = [64, 128, 256]
    max_iters = [50, 100]
    learning_rates = [0.01, 0.03, 0.1]
    
    # Track best model and score
    best_mlp_model = None
    best_mlp_val_f1 = 0
    best_mlp_params = {}
    best_train_predictions = None
    best_val_predictions = None
    best_mlp_train_f1 = 0
    
    # Start timing
    start_time = time.time()
    
    print("Training MLP models with random search...")
    total_combinations = len(layers_options) * len(block_sizes) * len(max_iters) * len(learning_rates)
    current_combination = 0
    
    # Manually iterate through parameter combinations
    for layers in layers_options:
        for block_size in block_sizes:
            for max_iter in max_iters:
                for step_size in learning_rates:
                    current_combination += 1
                    print(f"\rTrying combination {current_combination}/{total_combinations}", end="")
                    
                    # Initialize MLP with current parameters
                    mlp = MultilayerPerceptronClassifier(
                        labelCol="label",
                        featuresCol="features",
                        layers=layers,
                        blockSize=block_size,
                        maxIter=max_iter,
                        stepSize=step_size,
                        seed=42
                    )
                    
                    # Train and evaluate the model
                    mlp_model = mlp.fit(train_data)
                    train_predictions = mlp_model.transform(train_data)
                    val_predictions = mlp_model.transform(val_data)
                    mlp_train_f1 = evaluator.evaluate(train_predictions)
                    mlp_val_f1 = evaluator.evaluate(val_predictions)
                    
                    # Update best model if this one is better
                    if mlp_val_f1 > best_mlp_val_f1:
                        best_mlp_val_f1 = mlp_val_f1
                        best_mlp_train_f1 = mlp_train_f1
                        best_mlp_model = mlp_model
                        best_train_predictions = train_predictions
                        best_val_predictions = val_predictions
                        best_mlp_params = {
                            'layers': layers,
                            'blockSize': block_size,
                            'maxIter': max_iter,
                            'stepSize': step_size
                        }
    
    # End timing
    training_time = time.time() - start_time
    print(f"\nTraining completed in {training_time:.2f} seconds")
    
    return best_mlp_model, best_mlp_params, best_train_predictions, best_val_predictions, best_mlp_train_f1, best_mlp_val_f1

### 5.2 Evaluation Functions

In [None]:
# Helper function to convert PySpark predictions to numpy arrays for plotting
def get_prediction_labels(predictions_df):
    """Extract prediction and label columns from PySpark DataFrame."""
    pred_labels = predictions_df.select("prediction", "label").toPandas()
    y_pred = pred_labels["prediction"].values
    y_true = pred_labels["label"].values
    return y_pred, y_true

# Helper function to get prediction probabilities
def get_prediction_probabilities(predictions_df):
    """Extract probability column from PySpark DataFrame."""
    # Handle the warning about Arrow conversion by manually converting to NumPy
    probability_rows = predictions_df.select("probability").collect()
    return np.array([row.probability.toArray() for row in probability_rows])

# Helper function to plot confusion matrix
def plot_confusion_matrix(y_true, y_pred, title="Confusion Matrix", class_names=None, normalize=False):
    """Plot confusion matrix using seaborn with enhanced visualization."""
    cm = confusion_matrix(y_true, y_pred)
    
    # Normalization option
    if normalize:
        cm_sum = np.sum(cm, axis=1, keepdims=True)
        cm_percentage = cm.astype('float') / cm_sum * 100
        fmt = '.1f'
        cm_display = cm_percentage
    else:
        fmt = 'd'
        cm_display = cm
    
    plt.figure(figsize=(10, 8))
    
    # Create a more detailed heatmap
    ax = sns.heatmap(cm_display, annot=True, fmt=fmt, cmap="Blues", cbar=True,
                     linewidths=1, linecolor='black')
    
    # Add row and column totals
    row_sums = np.sum(cm, axis=1)
    col_sums = np.sum(cm, axis=0)
    
    # Add class names if provided
    if class_names is not None:
        tick_labels = class_names
    else:
        tick_labels = np.arange(len(cm))
    
    plt.xticks(np.arange(len(tick_labels)) + 0.5, tick_labels)
    plt.yticks(np.arange(len(tick_labels)) + 0.5, tick_labels)
    
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    
    # Add accuracy text
    accuracy = np.sum(np.diag(cm)) / np.sum(cm)
    plt.text(len(cm)/2, -0.5, f"Accuracy: {accuracy:.4f}", ha='center', fontsize=12)
    
    plt.tight_layout()
    plt.show()
    
    # Also show the normalized version if not already showing it
    if not normalize:
        plot_confusion_matrix(y_true, y_pred, f"{title} (Normalized %)", class_names, normalize=True)

# Helper function to print classification report
def print_classification_report(y_true, y_pred, class_names=None):
    """Print classification report with precision, recall, and F1 scores."""
    target_names = class_names if class_names is not None else None
    report = classification_report(y_true, y_pred, target_names=target_names)
    print("Classification Report:")
    print(report)
    
    # Calculate per-class metrics
    precision = {}
    recall = {}
    f1 = {}
    support = {}
    
    for class_idx in np.unique(y_true):
        true_positives = np.sum((y_true == class_idx) & (y_pred == class_idx))
        false_positives = np.sum((y_true != class_idx) & (y_pred == class_idx))
        false_negatives = np.sum((y_true == class_idx) & (y_pred != class_idx))
        
        class_support = np.sum(y_true == class_idx)
        
        precision[class_idx] = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall[class_idx] = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1[class_idx] = 2 * precision[class_idx] * recall[class_idx] / (precision[class_idx] + recall[class_idx]) if (precision[class_idx] + recall[class_idx]) > 0 else 0
        support[class_idx] = class_support
    
    # Visual representation of per-class metrics
    plt.figure(figsize=(12, 8))
    
    # Get unique classes ensuring they're sorted
    classes = sorted(np.unique(y_true))
    x = np.arange(len(classes))
    width = 0.25
    
    # Plot bars for each metric
    precision_vals = [precision[cls] for cls in classes]
    recall_vals = [recall[cls] for cls in classes]
    f1_vals = [f1[cls] for cls in classes]
    
    plt.bar(x - width, precision_vals, width, label='Precision')
    plt.bar(x, recall_vals, width, label='Recall')
    plt.bar(x + width, f1_vals, width, label='F1')
    
    plt.ylabel('Score')
    plt.title('Per-Class Performance Metrics')
    plt.xticks(x, [f'Class {cls}' for cls in classes])
    plt.ylim(0, 1.1)  # Metrics are between 0 and 1
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()
    
    # Display a bar chart of support (class frequencies)
    plt.figure(figsize=(10, 6))
    support_vals = [support[cls] for cls in classes]
    plt.bar(x, support_vals)
    plt.ylabel('Support (Number of Samples)')
    plt.title('Class Distribution in Evaluation Set')
    plt.xticks(x, [f'Class {cls}' for cls in classes])
    
    # Add value labels
    for i, v in enumerate(support_vals):
        plt.text(i, v + 0.5, str(v), ha='center')
        
    plt.tight_layout()
    plt.show()

# Helper function to plot ROC curve for multi-class
def plot_roc_curve(y_true, y_pred_proba, n_classes):
    """Plot ROC curve for multi-class classification."""
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    
    # Convert to one-hot encoding for ROC calculation
    y_true_onehot = np.zeros((len(y_true), n_classes))
    for i in range(len(y_true)):
        y_true_onehot[i, int(y_true[i])] = 1
    
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true_onehot[:, i], y_pred_proba[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    
    # Plot all ROC curves
    plt.figure(figsize=(12, 10))
    
    # Use a colormap for better differentiation between classes
    colors = plt.cm.jet(np.linspace(0, 1, n_classes))
    
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label=f'ROC curve of class {i} (area = {roc_auc[i]:.2f})')
    
    # Plot the diagonal
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    
    # Set plot properties
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class ROC Curve')
    plt.legend(loc="lower right")
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()
    
    # Return the average AUC
    return np.mean(list(roc_auc.values()))

# Helper function to plot precision-recall curve
def plot_precision_recall_curve(y_true, y_pred_proba, n_classes):
    """Plot Precision-Recall curve for multi-class classification."""
    # Compute precision-recall pairs for different probability thresholds
    precision = dict()
    recall = dict()
    avg_precision = dict()
    
    # Convert to one-hot encoding
    y_true_onehot = np.zeros((len(y_true), n_classes))
    for i in range(len(y_true)):
        y_true_onehot[i, int(y_true[i])] = 1
    
    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(y_true_onehot[:, i], y_pred_proba[:, i])
        avg_precision[i] = np.mean(precision[i])
    
    # Plot precision-recall curve for each class
    plt.figure(figsize=(12, 10))
    
    # Use a colormap for better differentiation between classes
    colors = plt.cm.jet(np.linspace(0, 1, n_classes))
    
    for i, color in zip(range(n_classes), colors):
        plt.plot(recall[i], precision[i], color=color, lw=2,
                 label=f'Class {i} (avg precision = {avg_precision[i]:.2f})')
    
    # Set plot properties
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Multi-class Precision-Recall Curve')
    plt.legend(loc="lower left")
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.tight_layout()
    plt.show()
    
    return np.mean(list(avg_precision.values()))

def evaluate_model(model_name, val_predictions, num_classes, has_probability=True):
    """Perform comprehensive evaluation of a model.
    
    Args:
        model_name: Name of the model
        val_predictions: PySpark DataFrame with predictions
        num_classes: Number of classes
        has_probability: Whether the model outputs probability scores
        
    Returns:
        tuple: (y_pred, y_true, y_pred_proba)
    """
    print(f"\n--- {model_name} Evaluation ---")
    
    # Calculate various metrics
    evaluator_f1 = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='prediction', metricName='f1')
    evaluator_precision = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='prediction', metricName='weightedPrecision')
    evaluator_recall = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='prediction', metricName='weightedRecall')
    evaluator_accuracy = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='prediction', metricName='accuracy')
    
    f1 = evaluator_f1.evaluate(val_predictions)
    precision = evaluator_precision.evaluate(val_predictions)
    recall = evaluator_recall.evaluate(val_predictions)
    accuracy = evaluator_accuracy.evaluate(val_predictions)
    
    print(f"\nOverall Metrics:")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    
    # Extract predictions and true labels
    y_pred, y_true = get_prediction_labels(val_predictions)
    
    # Plot confusion matrix
    print("\nConfusion Matrix:")
    plot_confusion_matrix(y_true, y_pred, f"{model_name} Confusion Matrix")
    
    # Print classification report
    print(f"\n{model_name} Classification Report:")
    print_classification_report(y_true, y_pred)
    
    # If model has probability outputs, plot ROC and PR curves
    y_pred_proba = None
    if has_probability:
        try:
            y_pred_proba = get_prediction_probabilities(val_predictions)
            
            # Plot ROC curve
            print("\nROC Curve:")
            auc_score = plot_roc_curve(y_true, y_pred_proba, num_classes)
            print(f"{model_name} Average AUC: {auc_score:.4f}")
            
            # Plot Precision-Recall curve
            print("\nPrecision-Recall Curve:")
            avg_precision = plot_precision_recall_curve(y_true, y_pred_proba, num_classes)
            print(f"{model_name} Average Precision: {avg_precision:.4f}")
        except Exception as e:
            print(f"Error generating probability-based metrics: {str(e)}")
            y_pred_proba = None
    
    return y_pred, y_true, y_pred_proba

## 6. Model Training

### 6.1 Logistic Regression

In [None]:
if RUN_MODELS['logistic_regression']:
    print("\n==== Logistic Regression Model ====")
    
    # Initialize the Logistic Regression model with optional class weights
    if IMBALANCE_CONFIG.get('use_class_weights', False) and 'weight' in train_data.columns:
        log_reg = LogisticRegression(labelCol='label', featuresCol='features', 
                                    predictionCol='prediction', weightCol='weight')
        print("Using class weights for Logistic Regression")
    else:
        log_reg = LogisticRegression(labelCol='label', featuresCol='features', 
                                    predictionCol='prediction')
    
    # Define the parameter grid for logistic regression
    lr_param_grid = ParamGridBuilder() \
        .addGrid(log_reg.regParam, [0.01, 0.1, 1.0]) \
        .addGrid(log_reg.elasticNetParam, [0.0, 0.5, 1.0]) \
        .addGrid(log_reg.maxIter, [10, 20]) \
        .addGrid(log_reg.family, ["multinomial"]) \
        .build()
    
    # Perform random search
    lr_best_model, lr_best_params, lr_train_preds, lr_val_preds, lr_train_f1, lr_val_f1 = perform_random_search(
        log_reg, 
        lr_param_grid, 
        train_data, 
        val_data, 
        num_folds=RANDOM_SEARCH_CONFIG['num_folds'],
        parallelism=RANDOM_SEARCH_CONFIG['parallelism'],
        use_f1=True,  # Use F1 for imbalanced classification
        class_weights=class_weights
    )
    
    # Print best parameters and performance
    print("\nBest Logistic Regression Parameters:")
    for param, value in lr_best_params.items():
        print(f"  {param}: {value}")
    
    print(f"\nLogistic Regression - Training F1 Score: {lr_train_f1:.4f}")
    print(f"Logistic Regression - Validation F1 Score: {lr_val_f1:.4f}")
    
    # Run comprehensive evaluation
    lr_y_pred, lr_y_true, lr_y_pred_proba = evaluate_model("Logistic Regression", lr_val_preds, num_classes)
else:
    print("Skipping Logistic Regression")

### 6.2 Random Forest

In [None]:
if RUN_MODELS['random_forest']:
    print("\n==== Random Forest Model ====")
    
    # Initialize Random Forest Classifier with class weights if available
    if IMBALANCE_CONFIG.get('use_class_weights', False) and 'weight' in train_data.columns:
        rf = RandomForestClassifier(labelCol="label", featuresCol="features", weightCol="weight")
        print("Using class weights for Random Forest")
    else:
        rf = RandomForestClassifier(labelCol="label", featuresCol="features")
    
    # Define parameter grid for Random Forest
    rf_param_grid = ParamGridBuilder() \
        .addGrid(rf.numTrees, [50, 100]) \
        .addGrid(rf.maxDepth, [5, 10, 15]) \
        .addGrid(rf.impurity, ["gini", "entropy"]) \
        .addGrid(rf.minInstancesPerNode, [1, 2]) \
        .build()
    
    # Perform random search
    rf_best_model, rf_best_params, rf_train_preds, rf_val_preds, rf_train_f1, rf_val_f1 = perform_random_search(
        rf, 
        rf_param_grid, 
        train_data, 
        val_data, 
        num_folds=RANDOM_SEARCH_CONFIG['num_folds'],
        parallelism=RANDOM_SEARCH_CONFIG['parallelism'],
        use_f1=True,
        class_weights=class_weights
    )
    
    # Print best parameters and performance
    print("\nBest Random Forest Parameters:")
    for param, value in rf_best_params.items():
        print(f"  {param}: {value}")
    
    print(f"\nRandom Forest - Training F1 Score: {rf_train_f1:.4f}")
    print(f"Random Forest - Validation F1 Score: {rf_val_f1:.4f}")
    
    # Run comprehensive evaluation
    rf_y_pred, rf_y_true, rf_y_pred_proba = evaluate_model("Random Forest", rf_val_preds, num_classes)
else:
    print("Skipping Random Forest")

### 6.3 Gradient Boosted Trees (GBDT)

In [None]:
if RUN_MODELS['gbdt']:
    print("\n==== Gradient Boosted Trees Model ====")
    
    # Initialize GBT Classifier with class weights if available
    if IMBALANCE_CONFIG.get('use_class_weights', False) and 'weight' in train_data.columns:
        gbt = GBTClassifier(labelCol="label", featuresCol="features", weightCol="weight")
        print("Using class weights for GBDT")
    else:
        gbt = GBTClassifier(labelCol="label", featuresCol="features")
    
    # Define parameter grid for GBT
    gbt_param_grid = ParamGridBuilder() \
        .addGrid(gbt.maxIter, [10, 20]) \
        .addGrid(gbt.maxDepth, [3, 5]) \
        .addGrid(gbt.stepSize, [0.05, 0.1]) \
        .addGrid(gbt.minInstancesPerNode, [1, 2]) \
        .build()
    
    # Perform random search
    gbt_best_model, gbt_best_params, gbt_train_preds, gbt_val_preds, gbt_train_f1, gbt_val_f1 = perform_random_search(
        gbt, 
        gbt_param_grid, 
        train_data, 
        val_data, 
        num_folds=RANDOM_SEARCH_CONFIG['num_folds'],
        parallelism=RANDOM_SEARCH_CONFIG['parallelism'],
        use_f1=True,
        class_weights=class_weights
    )
    
    # Print best parameters and performance
    print("\nBest GBDT Parameters:")
    for param, value in gbt_best_params.items():
        print(f"  {param}: {value}")
    
    print(f"\nGBDT - Training F1 Score: {gbt_train_f1:.4f}")
    print(f"GBDT - Validation F1 Score: {gbt_val_f1:.4f}")
    
    # GBT models don't provide probability outputs in Spark
    gbt_y_pred, gbt_y_true, _ = evaluate_model("GBDT", gbt_val_preds, num_classes, has_probability=False)
else:
    print("Skipping GBDT")

### 6.4 Multilayer Perceptron (MLP)

In [None]:
if RUN_MODELS['mlp']:
    print("\n==== Multilayer Perceptron Model ====")
    
    # Get number of features
    num_features = len(train_data.select("features").first()[0])
    
    # Perform custom random search for MLP
    mlp_best_model, mlp_best_params, mlp_train_preds, mlp_val_preds, mlp_train_f1, mlp_val_f1 = perform_mlp_random_search(
        train_data, 
        val_data, 
        num_features, 
        num_classes
    )
    
    # Print best parameters and performance
    print("\nBest MLP Parameters:")
    for param, value in mlp_best_params.items():
        print(f"  {param}: {value}")
    
    print(f"\nMLP - Training F1 Score: {mlp_train_f1:.4f}")
    print(f"MLP - Validation F1 Score: {mlp_val_f1:.4f}")
    
    # MLP models don't provide probability outputs in Spark
    mlp_y_pred, mlp_y_true, _ = evaluate_model("MLP", mlp_val_preds, num_classes, has_probability=False)
else:
    print("Skipping MLP")

## 7. Model Comparison

In [None]:
# Collect model names and scores for models that were run
model_names = []
train_scores = []
val_scores = []

if RUN_MODELS['logistic_regression']:
    model_names.append("Logistic Regression")
    train_scores.append(lr_train_f1)
    val_scores.append(lr_val_f1)
    
if RUN_MODELS['random_forest']:
    model_names.append("Random Forest")
    train_scores.append(rf_train_f1)
    val_scores.append(rf_val_f1)
    
if RUN_MODELS['gbdt']:
    model_names.append("GBDT")
    train_scores.append(gbt_train_f1)
    val_scores.append(gbt_val_f1)
    
if RUN_MODELS['mlp']:
    model_names.append("MLP")
    train_scores.append(mlp_train_f1)
    val_scores.append(mlp_val_f1)

# Check if we have any models to compare
if not model_names:
    print("No models were run for comparison.")
else:
    # Create a comparison DataFrame
    model_comparison = pd.DataFrame({
        'Model': model_names,
        'Training F1': train_scores,
        'Validation F1': val_scores,
        'Difference (Train-Val)': [train - val for train, val in zip(train_scores, val_scores)]
    })
    
    # Sort by validation F1 score, descending
    model_comparison = model_comparison.sort_values('Validation F1', ascending=False).reset_index(drop=True)
    
    print("Model Performance Comparison:")
    print(model_comparison)
    
    # Plot model comparison
    plt.figure(figsize=(12, 6))
    
    # Reorder based on sorted DataFrame
    sorted_models = model_comparison['Model'].tolist()
    sorted_train = model_comparison['Training F1'].tolist()
    sorted_val = model_comparison['Validation F1'].tolist()
    
    ind = np.arange(len(sorted_models))
    width = 0.35
    
    plt.bar(ind - width/2, sorted_train, width, label='Training F1', color='skyblue')
    plt.bar(ind + width/2, sorted_val, width, label='Validation F1', color='salmon')
    
    plt.ylabel('F1 Score')
    plt.title('Model Comparison (Sorted by Validation F1)')
    plt.xticks(ind, sorted_models, rotation=15)
    plt.legend(loc='best')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Add value labels on top of bars
    for i, v in enumerate(sorted_train):
        plt.text(i - width/2, v + 0.01, f'{v:.4f}', ha='center')
        
    for i, v in enumerate(sorted_val):
        plt.text(i + width/2, v + 0.01, f'{v:.4f}', ha='center')
    
    plt.tight_layout()
    plt.show()

## 8. Threshold Optimization

In [None]:
def optimize_threshold(model_name, predictions_df, y_pred_proba=None, optimize_for='f1'):
    """Optimize prediction threshold for imbalanced classes.
    
    Args:
        model_name: Name of the model
        predictions_df: PySpark DataFrame with predictions
        y_pred_proba: Optional numpy array of prediction probabilities
        optimize_for: Metric to optimize ('f1', 'recall', 'precision', or 'balanced')
    
    Returns:
        tuple: (optimized_df, optimal_thresholds)
    """
    if y_pred_proba is None or not IMBALANCE_CONFIG.get('adjust_threshold', False):
        print(f"\nSkipping threshold optimization for {model_name}")
        return predictions_df, None
    
    print(f"\nOptimizing thresholds for {model_name}...")
    
    # Extract true labels
    y_true = predictions_df.select("label").toPandas()["label"].values
    
    # Find optimal thresholds for each class
    optimal_thresholds = []
    
    # Convert to one-hot encoding
    y_true_onehot = np.zeros((len(y_true), num_classes))
    for i in range(len(y_true)):
        y_true_onehot[i, int(y_true[i])] = 1
    
    # For each class, find threshold that maximizes chosen metric
    for i in range(num_classes):
        best_metric = 0
        best_threshold = 0.5  # Default
        
        # Try different thresholds
        for threshold in np.arange(0.1, 0.9, 0.05):
            # Create binary predictions using this threshold
            pred_i = (y_pred_proba[:, i] >= threshold).astype(int)
            
            # Calculate metrics
            tp = np.sum((pred_i == 1) & (y_true_onehot[:, i] == 1))
            fp = np.sum((pred_i == 1) & (y_true_onehot[:, i] == 0))
            fn = np.sum((pred_i == 0) & (y_true_onehot[:, i] == 1))
            
            # Avoid division by zero
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            
            # Calculate metric based on what we're optimizing for
            if optimize_for == 'f1':
                metric = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            elif optimize_for == 'recall':
                metric = recall
            elif optimize_for == 'precision':
                metric = precision
            elif optimize_for == 'balanced':
                metric = (2 * recall + precision) / 3  # Emphasize recall more
            
            # Update if better
            if metric > best_metric:
                best_metric = metric
                best_threshold = threshold
        
        optimal_thresholds.append(best_threshold)
        
        # For minority classes, we may want to lower the threshold further
        if i > 0:  # Assuming class 0 is majority
            optimal_thresholds[i] = max(0.1, optimal_thresholds[i] * 0.8)  # Lower by 20%
    
    print(f"Optimal thresholds: {optimal_thresholds}")
    
    # Create a UDF to apply these thresholds
    def create_adjusted_prediction(prob_vector, thresholds):
        """Adjust prediction based on optimized thresholds."""
        def apply_thresholds(prob_array):
            max_score = -1
            max_class = 0
            for i, (prob, threshold) in enumerate(zip(prob_array, thresholds)):
                # Calculate score as how much it exceeds its threshold
                score = prob / threshold if threshold > 0 else 0
                if score > max_score:
                    max_score = score
                    max_class = i
            return float(max_class)
        return udf(apply_thresholds, returnType=DoubleType())
    
    # Apply the optimized thresholds
    threshold_udf = create_adjusted_prediction("probability", optimal_thresholds)
    optimized_df = predictions_df.withColumn("optimized_prediction", threshold_udf(col("probability")))
    
    # Evaluate the optimized predictions
    print("\nEvaluation after threshold optimization:")
    
    # Calculate metrics with optimized predictions
    evaluator_f1 = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='optimized_prediction', metricName='f1')
    evaluator_precision = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='optimized_prediction', metricName='weightedPrecision')
    evaluator_recall = MulticlassClassificationEvaluator(
        labelCol='label', predictionCol='optimized_prediction', metricName='weightedRecall')
    
    f1 = evaluator_f1.evaluate(optimized_df)
    precision = evaluator_precision.evaluate(optimized_df)
    recall = evaluator_recall.evaluate(optimized_df)
    
    print(f"F1 Score (Optimized): {f1:.4f}")
    print(f"Precision (Optimized): {precision:.4f}")
    print(f"Recall (Optimized): {recall:.4f}")
    
    # Extract and display confusion matrix
    y_pred_opt = optimized_df.select("optimized_prediction", "label").toPandas()
    y_pred_opt_values = y_pred_opt["optimized_prediction"].values
    y_true_values = y_pred_opt["label"].values
    
    print("\nConfusion Matrix with Optimized Thresholds:")
    plot_confusion_matrix(y_true_values, y_pred_opt_values, 
                         f"{model_name} Confusion Matrix (Optimized Thresholds)")
    
    # Print classification report
    print(f"\n{model_name} Classification Report (Optimized Thresholds):")
    print_classification_report(y_true_values, y_pred_opt_values)
    
    return optimized_df, optimal_thresholds

# Apply threshold optimization to models with probability output
if RUN_MODELS['logistic_regression'] and IMBALANCE_CONFIG.get('adjust_threshold', False):
    lr_val_preds_opt, lr_thresholds = optimize_threshold(
        "Logistic Regression", lr_val_preds, lr_y_pred_proba, optimize_for='balanced')

if RUN_MODELS['random_forest'] and IMBALANCE_CONFIG.get('adjust_threshold', False):
    rf_val_preds_opt, rf_thresholds = optimize_threshold(
        "Random Forest", rf_val_preds, rf_y_pred_proba, optimize_for='balanced')

## 9. Test Set Evaluation with Best Model

In [None]:
if model_names:  # Only run if we have models
    # Find the best model based on validation F1 scores
    best_model_index = val_scores.index(max(val_scores))
    best_model_name = model_names[best_model_index]
    print(f"Best Model: {best_model_name} with Validation F1: {max(val_scores):.4f}")
    
    # Get the corresponding model object and thresholds
    best_thresholds = None
    if best_model_name == "Logistic Regression" and RUN_MODELS['logistic_regression']:
        best_model = lr_best_model
        has_probability = True
        if IMBALANCE_CONFIG.get('adjust_threshold', False):
            best_thresholds = lr_thresholds
    elif best_model_name == "Random Forest" and RUN_MODELS['random_forest']:
        best_model = rf_best_model
        has_probability = True
        if IMBALANCE_CONFIG.get('adjust_threshold', False):
            best_thresholds = rf_thresholds
    elif best_model_name == "GBDT" and RUN_MODELS['gbdt']:
        best_model = gbt_best_model
        has_probability = False
    elif best_model_name == "MLP" and RUN_MODELS['mlp']:
        best_model = mlp_best_model
        has_probability = False
    
    # Make predictions on the test set
    test_predictions = best_model.transform(test_data)
    
    # Apply optimized thresholds if available
    if best_thresholds is not None and has_probability:
        print(f"\nApplying optimized thresholds to test set predictions...")
        
        # Create a UDF to apply these thresholds
        def create_adjusted_prediction(prob_vector, thresholds):
            """Adjust prediction based on optimized thresholds."""
            def apply_thresholds(prob_array):
                max_score = -1
                max_class = 0
                for i, (prob, threshold) in enumerate(zip(prob_array, thresholds)):
                    # Calculate score as how much it exceeds its threshold
                    score = prob / threshold if threshold > 0 else 0
                    if score > max_score:
                        max_score = score
                        max_class = i
                return float(max_class)
            return udf(apply_thresholds, returnType=DoubleType())
        
        # Apply the optimized thresholds
        threshold_udf = create_adjusted_prediction("probability", best_thresholds)
        test_predictions = test_predictions.withColumn("prediction", threshold_udf(col("probability")))
    
    # Initialize the evaluator for F1 score
    evaluator = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction', metricName='f1')
    
    # Evaluate on test set
    test_f1 = evaluator.evaluate(test_predictions)
    print(f"Test F1 Score with {best_model_name}: {test_f1:.4f}")
    
    # Run comprehensive evaluation on test set
    print(f"\n--- {best_model_name} Test Set Evaluation ---")
    test_y_pred, test_y_true, test_y_pred_proba = evaluate_model(
        f"{best_model_name} (Test)", 
        test_predictions, 
        num_classes, 
        has_probability=has_probability
    )
else:
    print("No models were run, skipping test set evaluation.")

## 10. Conclusion

### Summary of Techniques for Handling Class Imbalance

In this notebook, we implemented a comprehensive approach to deal with severe class imbalance in our classification task:

1. **Class Distribution Analysis**:
   - Identified extreme imbalance with Class 0 dominating the dataset
   - Visualized and quantified the imbalance ratio

2. **Data-Level Techniques**:
   - **Undersampling**: Reduced the majority class to create more balanced training
   - **Oversampling**: Increased minority class representation
   - **Combined approach**: Applied both techniques for optimal balance

3. **Algorithm-Level Techniques**:
   - **Class Weights**: Added weights to make misclassifying minority classes more costly
   - **Threshold Optimization**: Adjusted prediction thresholds for each class
   - **Model Selection**: Used models that handle imbalance well (tree-based models)

4. **Evaluation Metrics**:
   - Focused on metrics beyond accuracy (F1, precision, recall)
   - Used per-class performance visualization
   - Enhanced confusion matrix visualization
   - Used Precision-Recall curves instead of just ROC

### Key Findings

1. The original highly imbalanced dataset (90%+ Class 0) led to models predicting mainly the majority class
2. Resampling techniques significantly improved minority class detection
3. Optimizing classification thresholds further improved recall for minority classes
4. The best model achieved a better balance of precision and recall across all classes

### Next Steps

1. Consider feature engineering specific to minority classes
2. Explore ensemble methods that combine multiple models
3. Gather more examples of minority classes if possible
4. Consider anomaly detection approaches for extremely rare classes