# PySpark MLlib Multi-Model Drug Interaction Prediction

This notebook implements and compares three different machine learning models using PySpark MLlib for drug interaction safety prediction:

1. **Logistic Regression** - Binary classification with regularization
2. **Random Forest Classifier** - Ensemble method for robust predictions
3. **Gradient Boosted Trees (GBT)** - Advanced boosting algorithm

## Key Features:
- **PySpark MLlib**: Distributed machine learning at scale
- **HDFS Integration**: Direct data loading from HDFS
- **Comprehensive Evaluation**: Multiple metrics and visualization
- **Model Comparison**: Side-by-side performance analysis

## Dataset:
- Source: HDFS path `hdfs://localhost:9000/output/combined_dataset_complete.csv`
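- Columns: drug combination columns (`drug1`, `drug2`, …), dosage information (`total_drugs`, `doses_per_24_hrs` where present), and the `safety_label` target (`unsafe` vs. everything else)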
- Complete dataset processing with PySpark

In [None]:
# Section 1: Environment Setup and Imports
import warnings
warnings.filterwarnings('ignore')

# PySpark imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, concat_ws, isnan, count, mean, stddev
from pyspark.sql.types import DoubleType

# PySpark ML imports
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

# Visualization and metrics
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from datetime import datetime

# Set plotting style.
# The style sheet is called 'seaborn-v0_8' from matplotlib 3.6 onwards; older
# releases only know it as 'seaborn'. Pick whichever name this installation
# provides instead of failing with OSError on older environments.
preferred_style = 'seaborn-v0_8'
plt.style.use(preferred_style if preferred_style in plt.style.available else 'seaborn')
sns.set_palette("husl")

print("‚úì All libraries imported successfully!")
print(f"‚è∞ Notebook started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

In [None]:
# Section 2: Initialize Spark Session
# Creates (or reuses) the session bound to `spark`, which every later cell uses.
print("üöÄ Initializing Spark Session...")

# Session options, applied to the builder in the order listed.
spark_settings = [
    ("spark.sql.adaptive.enabled", "true"),
    ("spark.sql.adaptive.coalescePartitions.enabled", "true"),
    ("fs.defaultFS", "hdfs://localhost:9000"),
    ("spark.driver.memory", "4g"),
    ("spark.executor.memory", "4g"),
    ("spark.sql.shuffle.partitions", "200"),
]

session_builder = SparkSession.builder.appName("DrugInteractionMLlib").master("local[*]")
for setting_key, setting_value in spark_settings:
    session_builder = session_builder.config(setting_key, setting_value)
spark = session_builder.getOrCreate()

# Raise Spark's log threshold to WARN to keep the notebook output readable.
spark_context = spark.sparkContext
spark_context.setLogLevel("WARN")

print("‚úì Spark Session initialized successfully!")
print(f"   Spark Version: {spark.version}")
print(f"   Master: {spark_context.master}")
print(f"   App Name: {spark_context.appName}")
print("\n" + "=" * 60)

In [None]:
# Section 3: Load Data from HDFS
# Reads the CSV into `df` (header row, inferred schema) and prints a quick
# overview. Any failure is reported and then re-raised so the run stops here.
hdfs_path = "hdfs://localhost:9000/output/combined_dataset_complete.csv"

print("üìä Loading Drug Interaction Dataset from HDFS...")
print(f"   Source: {hdfs_path}")

try:
    csv_reader = spark.read.option("header", "true").option("inferSchema", "true")
    df = csv_reader.csv(hdfs_path)

    record_count = df.count()
    column_count = len(df.columns)
    print("   ‚úì Dataset loaded successfully!")
    print(f"   Total records: {record_count:,}")
    print(f"   Total columns: {column_count}")

    # Schema as inferred by Spark
    print("\nüìã Dataset Schema:")
    df.printSchema()

    # First five rows, untruncated
    print("\nüìù Sample Records:")
    df.show(5, truncate=False)

    # Row count per safety_label value
    print("\nüìà Dataset Statistics:")
    label_counts = df.groupBy("safety_label").count()
    label_counts.show()

except Exception as e:
    print(f"   ‚ùå Error loading dataset: {e!s}")
    print("   Please ensure HDFS is running and the dataset exists at the specified path")
    raise

print("\n" + "=" * 60)

In [None]:
# Section 4: Data Preprocessing
# Produces the names the modelling cells rely on:
#   df_clean            rows with a non-null safety_label plus a numeric `label`
#   numerical_features  numeric feature column names
#   indexers            unfitted StringIndexer stages for the first drug columns
#   indexed_features    output column names of those indexers
#   feature_columns     numerical_features + indexed_features
print("üîÑ Preprocessing data for MLlib models...")

# Identify drug columns: names of the form 'drug<digits>' (e.g. drug1, drug2, ...)
drug_columns = [col_name for col_name in df.columns if col_name.startswith('drug') and col_name[4:].isdigit()]
print(f"   Found {len(drug_columns)} drug columns: {drug_columns}")

# Rows without a label cannot be used for supervised training
df_clean = df.na.drop(subset=['safety_label'])

# Convert safety_label to numeric: exactly "unsafe" -> 1.0, every other value -> 0.0
df_clean = df_clean.withColumn(
    "label",
    when(col("safety_label") == "unsafe", 1.0).otherwise(0.0)
)

print(f"   ‚úì Label encoding complete (safe=0, unsafe=1)")

# Create numerical features
# Handle total_drugs if it exists
if 'total_drugs' in df_clean.columns:
    df_clean = df_clean.withColumn('total_drugs', col('total_drugs').cast(DoubleType()))
else:
    # Derive the count from the drug columns. Seeding the sum with lit(0.0)
    # keeps the result a Column even when no drug columns were found (a plain
    # sum([]) is the Python int 0, which withColumn rejects) and yields a
    # double column, matching the cast in the branch above.
    drug_present_flags = [
        when(col(drug_col).isNotNull(), 1.0).otherwise(0.0)
        for drug_col in drug_columns
    ]
    df_clean = df_clean.withColumn('total_drugs', sum(drug_present_flags, lit(0.0)))

# Handle doses_per_24_hrs if it exists
if 'doses_per_24_hrs' in df_clean.columns:
    df_clean = df_clean.withColumn(
        'doses_per_24_hrs_numeric',
        when(col('doses_per_24_hrs').isNotNull(), col('doses_per_24_hrs').cast(DoubleType())).otherwise(0.0)
    )
    numerical_features = ['total_drugs', 'doses_per_24_hrs_numeric']
else:
    numerical_features = ['total_drugs']

# Fill null values in numerical features (a failed cast also leaves nulls behind)
for feature in numerical_features:
    df_clean = df_clean.fillna({feature: 0.0})

print(f"   ‚úì Numerical features prepared: {numerical_features}")

# String indexing for drug columns (create indices for each drug)
indexed_features = []
indexers = []

for drug_col in drug_columns[:3]:  # Use first 3 drug columns to keep feature space manageable
    if drug_col in df_clean.columns:
        # Fill null values with 'NONE' so a missing drug becomes its own category
        df_clean = df_clean.fillna({drug_col: 'NONE'})

        # handleInvalid="keep": values unseen at fit time do not fail transform
        indexer = StringIndexer(
            inputCol=drug_col,
            outputCol=f"{drug_col}_index",
            handleInvalid="keep"
        )
        indexers.append(indexer)
        indexed_features.append(f"{drug_col}_index")

print(f"   ‚úì Drug indexing configured for {len(indexers)} columns")

# Combine all features
feature_columns = numerical_features + indexed_features

print(f"   ‚úì Total feature columns: {len(feature_columns)}")
print(f"\n   Feature list: {feature_columns}")

# Show class distribution
print("\nüìä Class Distribution:")
df_clean.groupBy("label").count().show()

print("\n" + "="*60)

In [None]:
# Section 5: Train-Test Split
# Produces `train_data` / `test_data`, used by all three model cells.
print("üìÇ Splitting data into training and test sets...")

# Split data (80% training, 20% testing)
train_data, test_data = df_clean.randomSplit([0.8, 0.2], seed=42)

# The splits are lazy: without caching, every action below and in later cells
# (counts, three pipeline fits on train_data, three transforms of test_data and
# each evaluation over them) would re-read the CSV and redo the preprocessing.
# Cache once so later actions reuse the materialised rows.
train_data = train_data.cache()
test_data = test_data.cache()

train_count = train_data.count()
test_count = test_data.count()
total_count = train_count + test_count

# Fail with a clear message instead of a ZeroDivisionError in the prints below
if total_count == 0:
    raise ValueError("No labelled rows available to split; check the loaded dataset.")

print(f"   ‚úì Training set: {train_count:,} records ({train_count / total_count * 100:.1f}%)")
print(f"   ‚úì Test set: {test_count:,} records ({test_count / total_count * 100:.1f}%)")

# Show label distribution in train and test
print("\n   Training set distribution:")
train_data.groupBy("label").count().show()

print("   Test set distribution:")
test_data.groupBy("label").count().show()

print("\n" + "="*60)

In [None]:
# Section 6: Model 1 - Logistic Regression
# Produces `lr_model` (fitted pipeline) and `lr_predictions` (test-set output).
print("ü§ñ Training Model 1: Logistic Regression")
print("="*60)

import time  # the Random Forest and GBT cells below also use `time`

# Build pipeline for Logistic Regression: index -> one-hot -> assemble -> scale -> LR
lr_pipeline_stages = indexers.copy()

# A StringIndexer output is a category id, not a magnitude. Feeding the raw id
# to a linear model would make it fit a slope across unrelated drug names, so
# expand each indexed column into a one-hot vector first.
# NOTE(review): the multi-column OneHotEncoder signature assumes Spark 3.x —
# confirm against the version printed in Section 2.
lr_encoded_features = [f"{index_col}_onehot" for index_col in indexed_features]
if indexed_features:
    lr_encoder = OneHotEncoder(
        inputCols=indexed_features,
        outputCols=lr_encoded_features,
        handleInvalid="keep"
    )
    lr_pipeline_stages.append(lr_encoder)

# Vector assembler: numeric columns plus the one-hot vectors
lr_assembler = VectorAssembler(
    inputCols=numerical_features + lr_encoded_features,
    outputCol="features",
    handleInvalid="keep"
)
lr_pipeline_stages.append(lr_assembler)

# Standard scaler for numerical stability
lr_scaler = StandardScaler(
    inputCol="features",
    outputCol="scaled_features"
)
lr_pipeline_stages.append(lr_scaler)

# Logistic Regression model (elasticNetParam=0.0 -> pure L2 penalty)
lr = LogisticRegression(
    featuresCol="scaled_features",
    labelCol="label",
    maxIter=100,
    regParam=0.01,
    elasticNetParam=0.0,
    family="binomial"
)
lr_pipeline_stages.append(lr)

# Create and fit pipeline
lr_pipeline = Pipeline(stages=lr_pipeline_stages)

print("   üîÑ Training Logistic Regression model...")
start_time = time.time()

lr_model = lr_pipeline.fit(train_data)

training_time = time.time() - start_time
print(f"   ‚úì Model trained in {training_time:.2f} seconds")

# Make predictions
lr_predictions = lr_model.transform(test_data)

print("\n   üìä Sample Predictions:")
lr_predictions.select("label", "prediction", "probability").show(10, truncate=False)

print("\n" + "="*60)

In [None]:
# Section 7: Model 2 - Random Forest Classifier
# Produces `rf_model` (fitted pipeline) and `rf_predictions` (test-set output).
print("üå≤ Training Model 2: Random Forest Classifier")
print("="*60)

# Build pipeline for Random Forest
rf_pipeline_stages = indexers.copy()

# Vector assembler
rf_assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol="features",
    handleInvalid="keep"
)
rf_pipeline_stages.append(rf_assembler)

# The indexed drug columns reach the tree learner as categorical features, and
# Spark's tree learners require maxBins to be at least the number of categories
# of every categorical feature (default maxBins is 32). Size maxBins from the
# training data so a drug column with many distinct names does not abort the
# fit. The +1 covers the extra bucket created by handleInvalid="keep".
rf_category_counts = [
    train_data.select(indexer.getInputCol()).distinct().count() + 1
    for indexer in indexers
]
rf_max_bins = max([32] + rf_category_counts)  # never below Spark's default

# Random Forest model
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    numTrees=100,
    maxDepth=10,
    minInstancesPerNode=1,
    maxBins=rf_max_bins,
    seed=42
)
rf_pipeline_stages.append(rf)

# Create and fit pipeline
rf_pipeline = Pipeline(stages=rf_pipeline_stages)

print("   üîÑ Training Random Forest model...")
start_time = time.time()

rf_model = rf_pipeline.fit(train_data)

training_time = time.time() - start_time
print(f"   ‚úì Model trained in {training_time:.2f} seconds")

# Make predictions
rf_predictions = rf_model.transform(test_data)

print("\n   üìä Sample Predictions:")
rf_predictions.select("label", "prediction", "probability").show(10, truncate=False)

# Get feature importances (the classifier is the last pipeline stage)
rf_classifier = rf_model.stages[-1]
feature_importances = rf_classifier.featureImportances
print(f"\n   üìà Top Feature Importances:")
# Rank by importance so the "top" list really is the top; zip stops at the
# shorter sequence, so an importance without a matching name is skipped.
ranked_importances = sorted(
    zip(feature_columns, feature_importances.toArray()),
    key=lambda name_and_score: name_and_score[1],
    reverse=True,
)
for feature_name, importance in ranked_importances[:10]:
    print(f"      {feature_name}: {importance:.4f}")

print("\n" + "="*60)

In [None]:
# Section 8: Model 3 - Gradient Boosted Trees
# Produces `gbt_model` (fitted pipeline) and `gbt_predictions` (test-set output).
print("üöÄ Training Model 3: Gradient Boosted Trees (GBT)")
print("="*60)

# Build pipeline for GBT
gbt_pipeline_stages = indexers.copy()

# Vector assembler
gbt_assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol="features",
    handleInvalid="keep"
)
gbt_pipeline_stages.append(gbt_assembler)

# The indexed drug columns reach the tree learner as categorical features, and
# Spark's tree learners require maxBins to be at least the number of categories
# of every categorical feature (default maxBins is 32). Size maxBins from the
# training data so a drug column with many distinct names does not abort the
# fit. The +1 covers the extra bucket created by handleInvalid="keep".
gbt_category_counts = [
    train_data.select(indexer.getInputCol()).distinct().count() + 1
    for indexer in indexers
]
gbt_max_bins = max([32] + gbt_category_counts)  # never below Spark's default

# Gradient Boosted Trees model
gbt = GBTClassifier(
    featuresCol="features",
    labelCol="label",
    maxIter=50,
    maxDepth=5,
    stepSize=0.1,
    maxBins=gbt_max_bins,
    seed=42
)
gbt_pipeline_stages.append(gbt)

# Create and fit pipeline
gbt_pipeline = Pipeline(stages=gbt_pipeline_stages)

print("   üîÑ Training Gradient Boosted Trees model...")
start_time = time.time()

gbt_model = gbt_pipeline.fit(train_data)

training_time = time.time() - start_time
print(f"   ‚úì Model trained in {training_time:.2f} seconds")

# Make predictions
gbt_predictions = gbt_model.transform(test_data)

print("\n   üìä Sample Predictions:")
gbt_predictions.select("label", "prediction", "probability").show(10, truncate=False)

# Get feature importances (the classifier is the last pipeline stage)
gbt_classifier = gbt_model.stages[-1]
feature_importances = gbt_classifier.featureImportances
print(f"\n   üìà Top Feature Importances:")
# Rank by importance so the "top" list really is the top; zip stops at the
# shorter sequence, so an importance without a matching name is skipped.
ranked_importances = sorted(
    zip(feature_columns, feature_importances.toArray()),
    key=lambda name_and_score: name_and_score[1],
    reverse=True,
)
for feature_name, importance in ranked_importances[:10]:
    print(f"      {feature_name}: {importance:.4f}")

print("\n" + "="*60)

In [None]:
# Section 9: Model Evaluation and Metrics
# Scores each model's test-set predictions and collects the results in
# `results_df`; `models_info`, `best_model` and `best_auc` are reused later.
print("üìä Evaluating All Models")
print("=" * 60)

# Initialize evaluators (the metric is chosen per call below)
binary_evaluator = BinaryClassificationEvaluator(labelCol="label")
multiclass_evaluator = MulticlassClassificationEvaluator(labelCol="label")

models_info = [
    ("Logistic Regression", lr_predictions),
    ("Random Forest", rf_predictions),
    ("Gradient Boosted Trees", gbt_predictions),
]


def _score(evaluator, scored_rows, metric):
    """Evaluate `scored_rows` with `evaluator`, overriding its metric for this call."""
    return evaluator.evaluate(scored_rows, {evaluator.metricName: metric})


results = []

for model_name, predictions in models_info:
    print(f"\nüîç Evaluating {model_name}:")
    print("-" * 60)

    # Threshold-free ranking metrics
    auc = _score(binary_evaluator, predictions, "areaUnderROC")
    pr_auc = _score(binary_evaluator, predictions, "areaUnderPR")

    # Label-based metrics (precision and recall are the class-weighted variants)
    accuracy = _score(multiclass_evaluator, predictions, "accuracy")
    precision = _score(multiclass_evaluator, predictions, "weightedPrecision")
    recall = _score(multiclass_evaluator, predictions, "weightedRecall")
    f1 = _score(multiclass_evaluator, predictions, "f1")

    # One summary row per model; key order fixes the column order of results_df
    result = {
        'Model': model_name,
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'ROC-AUC': auc,
        'PR-AUC': pr_auc,
    }
    results.append(result)

    # Print every metric of the row, aligned in one column
    for metric_label, metric_value in result.items():
        if metric_label == 'Model':
            continue
        heading = metric_label + ":"
        print(f"   {heading:<10} {metric_value:.4f}")

# Create results DataFrame
results_df = pd.DataFrame(results)
print("\n" + "=" * 60)
print("\nüìà Model Comparison Summary:")
print(results_df.to_string(index=False))

# Find best model: the row with the highest ROC-AUC
best_model_idx = results_df['ROC-AUC'].idxmax()
best_model = results_df.at[best_model_idx, 'Model']
best_auc = results_df.at[best_model_idx, 'ROC-AUC']

print(f"\nüèÜ Best Model: {best_model} (ROC-AUC: {best_auc:.4f})")
print("\n" + "=" * 60)

In [None]:
# Section 10: Confusion Matrices
# Section banner: title line followed by a 60-character rule.
print("üìä Generating Confusion Matrices", "=" * 60, sep="\n")

# Function to compute confusion matrix from predictions
def compute_confusion_matrix(predictions):
    """Build the 2x2 confusion matrix for a binary predictions DataFrame.

    The (label, prediction) pairs are counted inside Spark, so at most a
    handful of aggregated rows reach the driver instead of every scored row.

    Args:
        predictions: DataFrame with numeric "label" and "prediction" columns.

    Returns:
        Integer numpy array ``[[tn, fp], [fn, tp]]`` (rows are the actual
        class, columns the predicted class). Pairs whose label or prediction
        is not 0 or 1 are ignored.
    """
    pair_counts = predictions.groupBy("label", "prediction").count().collect()

    matrix = np.zeros((2, 2), dtype=int)
    for row in pair_counts:
        label, pred = row['label'], row['prediction']
        # Skip nulls or unexpected classes rather than mis-filing them
        if label in (0.0, 1.0) and pred in (0.0, 1.0):
            matrix[int(label), int(pred)] += row['count']

    return matrix

# One heatmap per model, side by side in a single row
class_names = ['Safe', 'Unsafe']
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Confusion Matrices for All Models', fontsize=16, fontweight='bold')

for panel, (model_name, predictions) in zip(axes, models_info):
    cm = compute_confusion_matrix(predictions)

    # Rows are the actual class, columns the predicted class
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        ax=panel,
        xticklabels=class_names,
        yticklabels=class_names,
    )
    panel.set_title(model_name, fontsize=12, fontweight='bold')
    panel.set_ylabel('Actual')
    panel.set_xlabel('Predicted')

plt.tight_layout()
plt.savefig('confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úì Confusion matrices saved as 'confusion_matrices.png'")
print("\n" + "=" * 60)

In [None]:
# Section 11: ROC Curves
# Section banner: title line followed by a 60-character rule.
print("üìà Generating ROC Curves", "=" * 60, sep="\n")

# Function to compute ROC curve points
def compute_roc_curve(predictions, num_points=100):
    """Compute ROC curve points from a binary predictions DataFrame.

    Args:
        predictions: DataFrame with a numeric "label" column (1 is the positive
            class) and a "probability" vector column whose element 1 is the
            positive-class probability.
        num_points: Maximum number of points to return. Longer curves are
            thinned to evenly spaced points, always keeping the first and last.
            ``None`` or a value below 2 disables thinning.

    Returns:
        ``(fpr, tpr)`` numpy arrays of equal length, starting at (0, 0).
    """
    # Select by column name and index the vector on the driver: Column.getItem()
    # cannot extract an element from an ML vector column, and avoiding `col`
    # keeps this function independent of that module-level name.
    # NOTE: this collects one row per scored example to the driver.
    scored_rows = predictions.select("label", "probability").collect()

    if not scored_rows:
        return np.array([0.0]), np.array([0.0])

    labels = np.array([float(row['label']) for row in scored_rows], dtype=float)
    probs = np.array([float(row['probability'][1]) for row in scored_rows], dtype=float)

    total_positives = np.sum(labels == 1)
    total_negatives = np.sum(labels == 0)

    # Highest score first; a stable sort keeps the result reproducible
    order = np.argsort(-probs, kind="mergesort")
    labels = labels[order]
    probs = probs[order]

    cumulative_tp = np.cumsum(labels == 1)
    cumulative_fp = np.cumsum(labels != 1)

    # Emit one point per distinct score (the last row of each run of ties).
    # Stepping row by row through tied scores would make the curve depend on
    # the arbitrary order of the tied rows.
    threshold_ends = np.r_[np.nonzero(np.diff(probs))[0], labels.size - 1]

    if total_positives > 0:
        tpr = cumulative_tp[threshold_ends] / total_positives
    else:
        tpr = np.zeros(threshold_ends.size)
    if total_negatives > 0:
        fpr = cumulative_fp[threshold_ends] / total_negatives
    else:
        fpr = np.zeros(threshold_ends.size)

    # The curve starts at the origin (threshold above every score)
    fpr = np.concatenate(([0.0], fpr))
    tpr = np.concatenate(([0.0], tpr))

    # Honour num_points: thin long curves to evenly spaced points
    if num_points is not None and num_points >= 2 and fpr.size > num_points:
        keep = np.unique(np.linspace(0, fpr.size - 1, num_points).round().astype(int))
        fpr, tpr = fpr[keep], tpr[keep]

    return fpr, tpr

# Draw all ROC curves on one shared set of axes
fig, ax = plt.subplots(figsize=(10, 8))

colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

for idx, (model_name, predictions) in enumerate(models_info):
    fpr, tpr = compute_roc_curve(predictions)
    # AUC comes from the evaluation table; row idx matches the models_info order
    auc_score = results_df.loc[idx, 'ROC-AUC']
    curve_label = f'{model_name} (AUC = {auc_score:.4f})'

    ax.plot(fpr, tpr, label=curve_label, color=colors[idx], linewidth=2)

# Reference diagonal for a random classifier
ax.plot([0, 1], [0, 1], 'k--', label='Random Classifier', linewidth=1)

ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('ROC Curves - Model Comparison', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('roc_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úì ROC curves saved as 'roc_curves.png'")
print("\n" + "=" * 60)

In [None]:
# Section 12: Metrics Comparison Visualization
# Draws one bar chart per metric from `results_df` on a 2x3 grid.
print("üìä Generating Metrics Comparison Charts")
print("="*60)

# Create bar plots for each metric
metrics_to_plot = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
bar_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Model Performance Comparison Across Metrics', fontsize=16, fontweight='bold')

# Walk the grid in row-major order. The panels are taken from the flattened
# array rather than via loop variables named `row` / `col`: assigning to `col`
# here would overwrite the imported pyspark.sql.functions.col and break any
# earlier cell that is re-run afterwards.
panel_axes = axes.flatten()

for ax, metric in zip(panel_axes, metrics_to_plot):
    # Create bar plot
    bars = ax.bar(results_df['Model'], results_df[metric],
                  color=bar_colors)

    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.4f}',
                ha='center', va='bottom', fontsize=9)

    ax.set_ylabel(metric, fontsize=11)
    ax.set_title(f'{metric} Comparison', fontsize=12, fontweight='bold')
    ax.set_ylim([0, 1.1])  # headroom above 1.0 for the value labels
    ax.grid(True, alpha=0.3, axis='y')

    # Rotate x-axis labels
    ax.set_xticklabels(results_df['Model'], rotation=45, ha='right')

# Remove every panel that has no metric (with five metrics, only the last one)
for unused_ax in panel_axes[len(metrics_to_plot):]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úì Metrics comparison saved as 'metrics_comparison.png'")
print("\n" + "="*60)

In [None]:
# Section 13: Feature Importance Visualization
# Compares the feature importances of the two tree-based models side by side.
print("üìä Generating Feature Importance Charts")
print("=" * 60)

# The classifier is the last stage of each fitted pipeline
rf_classifier = rf_model.stages[-1]
gbt_classifier = gbt_model.stages[-1]

rf_importances = rf_classifier.featureImportances.toArray()
gbt_importances = gbt_classifier.featureImportances.toArray()

# One row per feature, ordered by Random Forest importance (highest first)
n_features = len(feature_columns)
importance_df = pd.DataFrame({
    'Feature': feature_columns,
    'Random Forest': rf_importances[:n_features],
    'GBT': gbt_importances[:n_features],
}).sort_values('Random Forest', ascending=False)

# The GBT panel uses its own ordering
importance_df_gbt = importance_df.sort_values('GBT', ascending=False)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Feature Importance - Tree-Based Models', fontsize=16, fontweight='bold')

# (axes, data, value column, bar colour, panel title)
importance_panels = [
    (ax1, importance_df, 'Random Forest', '#4ECDC4', 'Random Forest Feature Importance'),
    (ax2, importance_df_gbt, 'GBT', '#45B7D1', 'Gradient Boosted Trees Feature Importance'),
]

for panel, panel_data, value_column, bar_color, panel_title in importance_panels:
    panel_data.plot(x='Feature', y=value_column, kind='barh', ax=panel,
                    color=bar_color, legend=False)
    panel.set_title(panel_title, fontsize=12, fontweight='bold')
    panel.set_xlabel('Importance', fontsize=11)
    panel.set_ylabel('Features', fontsize=11)
    panel.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.savefig('feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("‚úì Feature importance saved as 'feature_importance.png'")

print("\nüìã Top 10 Most Important Features:")
print(importance_df.head(10).to_string(index=False))

print("\n" + "=" * 60)

In [None]:
# Section 14: Final Summary and Model Persistence
# Recaps the evaluation table and saves the pipeline with the highest ROC-AUC.
rule = "=" * 60

print("\n" + rule)
print("üéâ MODEL TRAINING AND EVALUATION COMPLETE")
print(rule)

print("\nüìä Final Results Summary:")
print("-" * 60)
print(results_df.to_string(index=False))

print(f"\nüèÜ Best Performing Model: {best_model}")
print(f"   ROC-AUC Score: {best_auc:.4f}")

print("\nüìÅ Generated Files:")
generated_file_notes = [
    "   ‚úì confusion_matrices.png - Confusion matrices for all models",
    "   ‚úì roc_curves.png - ROC curve comparison",
    "   ‚úì metrics_comparison.png - Performance metrics comparison",
    "   ‚úì feature_importance.png - Feature importance analysis",
]
for file_note in generated_file_notes:
    print(file_note)

print("\nüíæ Saving Best Model...")
best_model_path = "best_model_" + best_model.replace(' ', '_').lower()

# Any name other than the two listed here falls through to the GBT pipeline
fitted_pipelines = {
    "Logistic Regression": lr_model,
    "Random Forest": rf_model,
}
pipeline_to_save = fitted_pipelines.get(best_model, gbt_model)
pipeline_to_save.write().overwrite().save(best_model_path)

print(f"   ‚úì Best model saved to: {best_model_path}")

print("\n" + rule)
print(f"‚è∞ Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(rule)

In [None]:
# Section 15: Cleanup
# Shuts the Spark session down; no Spark work can run after this cell.
print("\nüßπ Cleaning up resources...")

spark.stop()

print("‚úì Spark session stopped", "\n‚úÖ All operations completed successfully!", sep="\n")