# LightGBM Classification with Medical Claim Embeddings

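This notebook trains LightGBM binary classifiers on embeddings generated from medical claims data. It loads the embeddings (or generates and caches them), trains a baseline model with the native LightGBM API and early stopping, and tunes a scikit-learn `LGBMClassifier` with cross-validated grid search. It then compares both models on a held-out test set.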

In [None]:
# Import required libraries
# Standard library
import json
import sys
import warnings
from datetime import datetime
from pathlib import Path

warnings.filterwarnings('ignore')

# Third-party: numerics and plotting
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# LightGBM imports
import lightgbm as lgb
from lightgbm import LGBMClassifier

# Scikit-learn imports (joblib is installed as a scikit-learn dependency)
import joblib
from sklearn.model_selection import train_test_split, GridSearchCV, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve, confusion_matrix, classification_report
)
from sklearn.preprocessing import StandardScaler

# Local imports (project root is one level up from this notebook)
sys.path.append('..')
from pipelines.embedding_pipeline import EmbeddingPipeline
from models.config_models import PipelineConfig
from utils.logging_utils import get_logger

# Set random seeds
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
%matplotlib inline

## 1. Generate or Load Embeddings

In [None]:
# Configuration for embedding pipeline.
# NOTE(review): these relative paths resolve against the kernel's working
# directory, while the imports cell appends '..' to sys.path (suggesting the
# notebook sits one level below the project root) — confirm that 'data/' and
# 'outputs/' point at the intended locations.
data_path = Path('data/medical_claims_complete.csv')
embeddings_path = Path('outputs/lightgbm_embeddings.csv')

# Reuse cached embeddings when a previous run wrote them; otherwise run the
# embedding pipeline (presumably slow, since it calls the LLM endpoint
# configured under config['llm'] — TODO confirm).
if embeddings_path.exists():
    print("Loading existing embeddings...")
    embeddings_df = pd.read_csv(embeddings_path)
else:
    print("Generating new embeddings...")
    config = {
        'pipeline': {
            'job_name': 'lightgbm_embeddings',
            'log_level': 'INFO'
        },
        'data': {
            'data_path': str(data_path.absolute()),
            'claim_column': 'claim',
            'label_column': 'label',
            'mcid_column': 'mcid'
        },
        'llm': {
            # NOTE(review): hard-coded local endpoint; consider reading it from
            # an environment variable so the notebook is portable.
            'model_url': 'http://localhost:8000',
            'batch_size': 32,
            'max_retries': 3
        },
        'outputs': {
            'output_dir': 'outputs',
            'save_embeddings': True
        }
    }

    pipeline_config = PipelineConfig(**config)
    embedding_pipeline = EmbeddingPipeline(pipeline_config)
    # Assumes run() returns a pandas DataFrame with 'embedding_*' and 'label'
    # columns (the next section relies on both) — TODO confirm.
    embeddings_df = embedding_pipeline.run()

    # DataFrame.to_csv does not create missing parent directories; make sure
    # the cache directory exists so a fresh checkout does not fail here, after
    # the expensive pipeline has already run.
    embeddings_path.parent.mkdir(parents=True, exist_ok=True)
    embeddings_df.to_csv(embeddings_path, index=False)

print(f"Embeddings shape: {embeddings_df.shape}")
print(f"Columns: {embeddings_df.columns.tolist()[:10]}...")

## 2. Prepare Data for Training

In [None]:
# Extract features (every 'embedding_*' column) and the 'label' target.
embedding_cols = [col for col in embeddings_df.columns if col.startswith('embedding_')]
assert embedding_cols, "No 'embedding_*' columns found in embeddings_df"
X = embeddings_df[embedding_cols].values
y = embeddings_df['label'].values

# Stratified 80/20 split. The test set is reserved for final evaluation only.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# np.bincount assumes non-negative integer labels (presumably binary 0/1 here).
print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")
print(f"Class distribution - Train: {np.bincount(y_train)}")
print(f"Class distribution - Test: {np.bincount(y_test)}")

# Early stopping in the native-API model needs its own validation data.
# Carving it out of the training split (rather than reusing X_test) keeps the
# test set unseen, so the test metrics reported later are not optimistically
# biased by having been used to pick the number of boosting rounds.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)
print(f"Early-stopping split - fit: {X_fit.shape}, validation: {X_val.shape}")

# Create LightGBM datasets; the validation set reuses the training bin mappers.
train_data = lgb.Dataset(X_fit, label=y_fit)
valid_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

## 3. Train Basic LightGBM Model

In [None]:
# Basic LightGBM parameters
params = {
    'objective': 'binary',
    'metric': ['binary_logloss', 'auc'],
    'boosting_type': 'gbdt',
    'num_leaves': 31,
    'learning_rate': 0.1,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'verbose': 0,
    'seed': 42
}

# Train model
print("Training basic LightGBM model...")
lgb_model = lgb.train(
    params,
    train_data,
    valid_sets=[train_data, valid_data],
    valid_names=['train', 'eval'],
    num_boost_round=100,
    callbacks=[lgb.early_stopping(stopping_rounds=10), lgb.log_evaluation(10)]
)

# Make predictions
y_pred_proba = lgb_model.predict(X_test, num_iteration=lgb_model.best_iteration)
y_pred = (y_pred_proba > 0.5).astype(int)

# Calculate metrics
metrics = {
    'accuracy': accuracy_score(y_test, y_pred),
    'precision': precision_score(y_test, y_pred),
    'recall': recall_score(y_test, y_pred),
    'f1': f1_score(y_test, y_pred),
    'auc_roc': roc_auc_score(y_test, y_pred_proba)
}

print("\nBasic Model Performance:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

## 4. Hyperparameter Tuning with Scikit-learn API

In [None]:
# Define parameter grid for LightGBM
param_grid = {
    'n_estimators': [100, 200, 300],
    'num_leaves': [31, 50, 70],
    'learning_rate': [0.01, 0.1, 0.3],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0],
    'min_child_samples': [20, 30, 40]
}

# Create LightGBM classifier
lgbm_clf = LGBMClassifier(
    objective='binary',
    boosting_type='gbdt',
    metric='auc',
    random_state=42,
    n_jobs=-1
)

# Grid search with cross-validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
grid_search = GridSearchCV(
    lgbm_clf,
    param_grid,
    cv=cv,
    scoring='roc_auc',
    n_jobs=-1,
    verbose=1
)

print("Starting hyperparameter tuning...")
grid_search.fit(X_train, y_train)

# Best parameters
print(f"\nBest parameters: {grid_search.best_params_}")
print(f"Best CV score: {grid_search.best_score_:.4f}")

# Evaluate best model
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test)
y_pred_proba_best = best_model.predict_proba(X_test)[:, 1]

metrics_best = {
    'accuracy': accuracy_score(y_test, y_pred_best),
    'precision': precision_score(y_test, y_pred_best),
    'recall': recall_score(y_test, y_pred_best),
    'f1': f1_score(y_test, y_pred_best),
    'auc_roc': roc_auc_score(y_test, y_pred_proba_best)
}

print("\nTuned Model Performance:")
for metric, value in metrics_best.items():
    print(f"{metric}: {value:.4f}")

## 5. Feature Importance Analysis

In [None]:
# Get feature importance from best model
importance = best_model.feature_importances_
indices = np.argsort(importance)[::-1]

# Plot top 20 features
plt.figure(figsize=(12, 8))
top_n = 20
plt.bar(range(top_n), importance[indices[:top_n]])
plt.xlabel('Feature Index')
plt.ylabel('Feature Importance (Split)')
plt.title('Top 20 Most Important Features (LightGBM)')
plt.tight_layout()
plt.show()

# Feature importance distribution
plt.figure(figsize=(10, 6))
plt.hist(importance, bins=50, edgecolor='black')
plt.xlabel('Feature Importance')
plt.ylabel('Count')
plt.title('Distribution of Feature Importances')
plt.tight_layout()
plt.show()

# LightGBM native feature importance plot
lgb.plot_importance(lgb_model, max_num_features=20, figsize=(10, 8), importance_type='gain')
plt.title('Feature Importance by Gain')
plt.tight_layout()
plt.show()

print(f"Top 10 feature indices: {indices[:10]}")
print(f"Top 10 importance scores: {importance[indices[:10]]}")

## 6. Model Evaluation and Visualization

In [None]:
# Confusion Matrix
cm = confusion_matrix(y_test, y_pred_best)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix - LightGBM')
plt.tight_layout()
plt.show()

# ROC Curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba_best)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f'ROC curve (AUC = {metrics_best["auc_roc"]:.3f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - LightGBM')
plt.legend()
plt.tight_layout()
plt.show()

# Classification Report
print("\nDetailed Classification Report:")
print(classification_report(y_test, y_pred_best))

## 7. Training History Visualization

In [None]:
# Plot training metrics from basic model
results = lgb_model.evals_result_

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Log Loss
ax1.plot(results['train']['binary_logloss'], label='Train')
ax1.plot(results['eval']['binary_logloss'], label='Validation')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Log Loss')
ax1.set_title('LightGBM Training - Log Loss')
ax1.legend()

# AUC
ax2.plot(results['train']['auc'], label='Train')
ax2.plot(results['eval']['auc'], label='Validation')
ax2.set_xlabel('Iteration')
ax2.set_ylabel('AUC')
ax2.set_title('LightGBM Training - AUC')
ax2.legend()

plt.tight_layout()
plt.show()

print(f"Best iteration: {lgb_model.best_iteration}")
print(f"Best validation AUC: {max(results['eval']['auc']):.4f}")

## 8. Model Interpretation (SHAP) and Learning Curve

In [None]:
# SHAP value analysis (optional dependency). Only the import is guarded, so an
# ImportError raised while *computing* SHAP values is not misreported as
# "SHAP not installed".
try:
    import shap
except ImportError:
    shap = None
    print("SHAP not installed. Install with: pip install shap")

if shap is not None:
    # Explain a 100-row subset of the test set to keep runtime bounded.
    X_shap = X_test[:100]
    explainer = shap.Explainer(best_model, X_train)
    shap_values = explainer(X_shap)

    # Summary plot
    shap.summary_plot(shap_values, X_shap, show=False)
    plt.title('SHAP Summary Plot - Top Features')
    plt.tight_layout()
    plt.show()

# Learning curve: refit the tuned configuration on growing prefixes of the
# training split (already shuffled by train_test_split) and score accuracy on
# that prefix and on the held-out test set.
train_sizes = [0.1, 0.3, 0.5, 0.7, 0.9]
train_scores = []
test_scores = []

for size in train_sizes:
    n_samples = int(size * len(X_train))
    X_subset = X_train[:n_samples]
    y_subset = y_train[:n_samples]

    # Rebuild from the tuned model's full parameter set (objective, metric,
    # subsample_freq, random_state, ...). best_params_ alone holds only the
    # searched keys, so the curve would describe a different model.
    curve_model = LGBMClassifier(**best_model.get_params())
    curve_model.fit(X_subset, y_subset)

    train_scores.append(curve_model.score(X_subset, y_subset))
    test_scores.append(curve_model.score(X_test, y_test))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_sizes, train_scores, 'o-', label='Training score')
ax.plot(train_sizes, test_scores, 'o-', label='Test score')
ax.set_xlabel('Training Set Size (fraction)')
ax.set_ylabel('Accuracy Score')
ax.set_title('Learning Curve - LightGBM')
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

## 9. Save Model and Results

In [None]:
# Create output directory
output_dir = Path('outputs/lightgbm_model')
output_dir.mkdir(parents=True, exist_ok=True)

# Take one timestamp for the whole run so the model, pickle, and metrics files
# share the same suffix. Separate datetime.now() calls can straddle a second
# boundary and leave the artifacts unmatched.
run_time = datetime.now()
run_stamp = run_time.strftime("%Y%m%d_%H%M%S")

# Save model (native format)
model_path = output_dir / f'lightgbm_model_{run_stamp}.txt'
lgb_model.save_model(str(model_path))
print(f"Native model saved to: {model_path}")

# Save sklearn API model. Loading a joblib pickle executes code, so only load
# it back from trusted locations.
sklearn_model_path = output_dir / f'lightgbm_sklearn_{run_stamp}.pkl'
joblib.dump(best_model, sklearn_model_path)
print(f"Sklearn model saved to: {sklearn_model_path}")

# Save metrics and results. Named run_summary (not `results`) so it does not
# clobber other notebook state. Metric values are cast to plain floats so
# json.dump never sees numpy scalar types.
run_summary = {
    'model_type': 'LightGBM',
    'timestamp': run_time.isoformat(),
    'best_parameters': grid_search.best_params_,
    'cv_score': float(grid_search.best_score_),
    'test_metrics': {name: float(value) for name, value in metrics_best.items()},
    'feature_importance': {
        'top_features': indices[:20].tolist(),
        'importance_scores': importance[indices[:20]].tolist()
    },
    'training_info': {
        'best_iteration': int(lgb_model.best_iteration),
        'num_trees': best_model.n_estimators,
        'num_leaves': best_model.num_leaves
    },
    'data_info': {
        'n_train': len(X_train),
        'n_test': len(X_test),
        'n_features': X_train.shape[1],
        'class_distribution': {
            'train': np.bincount(y_train).tolist(),
            'test': np.bincount(y_test).tolist()
        }
    }
}

metrics_path = output_dir / f'lightgbm_metrics_{run_stamp}.json'
with open(metrics_path, 'w') as f:
    json.dump(run_summary, f, indent=2)
print(f"Metrics saved to: {metrics_path}")

# Display summary
print("\n=== LightGBM Model Summary ===")
print(f"Best AUC-ROC: {metrics_best['auc_roc']:.4f}")
print(f"Best F1 Score: {metrics_best['f1']:.4f}")
print(f"Number of trees: {best_model.n_estimators}")
print(f"Number of leaves: {best_model.num_leaves}")
print(f"Learning rate: {best_model.learning_rate}")
print(f"Best iteration (early stopping): {lgb_model.best_iteration}")

## 10. Model Comparison Summary

In [None]:
# Metric keys shared by `metrics` (basic) and `metrics_best` (tuned), mapped to
# display labels. This is the single source of truth for ordering and naming.
METRIC_LABELS = {
    'accuracy': 'Accuracy',
    'precision': 'Precision',
    'recall': 'Recall',
    'f1': 'F1 Score',
    'auc_roc': 'AUC-ROC',
}

# Create comparison table (one row per metric, one column per model)
comparison_df = pd.DataFrame(
    {
        'Basic LightGBM': [metrics[key] for key in METRIC_LABELS],
        'Tuned LightGBM': [metrics_best[key] for key in METRIC_LABELS],
    },
    index=pd.Index(list(METRIC_LABELS.values()), name='Metric'),
)

# Plot comparison
ax = comparison_df.plot(kind='bar', figsize=(10, 6), rot=0)
ax.set_ylabel('Score')
ax.set_title('LightGBM Model Performance Comparison')
ax.legend(loc='lower right')
ax.set_ylim(0, 1.05)

# Add value labels on bars
for container in ax.containers:
    ax.bar_label(container, fmt='%.3f')

plt.tight_layout()
plt.show()

# Print improvement summary. Every metric is a score in [0, 1], so
# (tuned - basic) * 100 is an absolute difference in percentage points, not a
# relative percent change. The label says "pp" to avoid overstating gains.
print("\n=== Performance Improvement (Tuned - Basic) ===")
for key, label in METRIC_LABELS.items():
    delta_pp = (metrics_best[key] - metrics[key]) * 100
    print(f"{label}: {delta_pp:+.2f} pp")