# Decision Tree Analysis for CJD Subtype Classification

This notebook implements a decision tree classifier to differentiate between CJD (Creutzfeldt-Jakob Disease) subtypes using protein biomarker data. The analysis focuses on creating an interpretable model that could potentially be used in clinical settings.

## Key Components

1. Data Preparation
   - Uses the top 10 biomarkers pre-selected by a previous feature-importance analysis
   - Excludes control samples, focusing only on disease subtypes (MM(V)1, MV2K, VV2)

2. Model Development
   - Implements a decision tree with controlled complexity (max_depth=4, min_samples_leaf=5)
   - Uses 5-fold cross-validation to assess model robustness
   - Evaluates performance using confusion matrices and classification reports

3. Visualization
   - Decision tree structure visualization
   - Feature importance analysis
   - Cross-validation performance metrics
   - Sample distribution in decision nodes

The goal is to create a simple, interpretable decision-making tool that could assist in CJD subtype classification using a minimal set of biomarkers. This approach prioritizes clinical interpretability over complex model architectures, making it potentially valuable for practical diagnostic support.

In [1]:
# Standard library imports
import os
import warnings

# Core data processing and numerical computations
import numpy as np
import pandas as pd

# Sklearn imports - organized by functionality
from sklearn.model_selection import (
    cross_val_score,
    cross_val_predict,
    StratifiedKFold
)
from sklearn.tree import (
    DecisionTreeClassifier,
    plot_tree,
    export_text
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Silence every warning category for the rest of the session
warnings.simplefilter("ignore")

In [2]:
# Define paths (both are resolved relative to the parent of the current
# working directory, so the notebook is expected to run from a sub-folder
# of the project)
data_path = os.path.dirname(os.getcwd()) + '/data'
figure_path = os.path.dirname(os.getcwd()) + '/figures/decision_tree'

# plt.savefig does not create missing directories; without this the first
# figure export fails with FileNotFoundError on a fresh checkout.
os.makedirs(figure_path, exist_ok=True)

# Import and prepare data: drop control samples so only disease subtypes remain
df = pd.read_excel(data_path + '/curated/olink.xlsx')
df = df[df['SubGroup'] != 'CTRL']

# Get top features: the first 10 rows of the 'Feature' column
# (assumes the CSV is already sorted by importance -- TODO confirm)
feature_importance_rankings = pd.read_csv(data_path + '/results/feature_importance_rankings.csv')
top_biomarkers = list(feature_importance_rankings['Feature'].head(10))

# Select columns: keep only the chosen biomarkers plus the label column
columns_to_select = top_biomarkers + ['SubGroup']
df = df[columns_to_select]

# Prepare features and target
X = df[top_biomarkers]
y = df['SubGroup']

In [None]:
# Shallow tree with a minimum leaf size; the fixed seed makes fitting repeatable
dt = DecisionTreeClassifier(max_depth=4, min_samples_leaf=5, random_state=42)

# Stratified, shuffled 5-fold cross-validation with a fixed seed
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(dt, X, y, cv=cv)
cv_mean = cv_scores.mean()
cv_spread = cv_scores.std() * 2
print(f"\nCross-validation scores: {cv_scores}")
print(f"Average CV score: {cv_mean:.3f} (+/- {cv_spread:.3f})")

# Cross-validated predictions, one per sample
y_pred_cv = cross_val_predict(dt, X, y, cv=cv)

# Per-class precision / recall / F1 computed from the CV predictions
print("\nClassification Report (from cross-validation):")
print(classification_report(y, y_pred_cv))

# Confusion matrix normalised over the true labels, drawn as a heatmap.
# The figure is written to disk and closed without being displayed.
class_labels = sorted(y.unique())
cm = confusion_matrix(y, y_pred_cv, normalize='true')
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=class_labels,
    yticklabels=class_labels,
    annot_kws={"size": 12},
)
plt.title('Confusion Matrix', fontsize=14)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig(f"{figure_path}/decision_tree_confusion_matrix.png", dpi=1200, bbox_inches='tight')
plt.close()

# The final model is trained on every available sample
dt.fit(X, y)

# Tabulate the fitted tree's importances, largest first
feature_importance = pd.DataFrame(
    {'feature': top_biomarkers, 'importance': dt.feature_importances_}
)
feature_importance = feature_importance.sort_values('importance', ascending=False)

print("\nFeature Importances in Decision Tree:")
print(feature_importance)

# Horizontal bar chart of the importances
plt.figure(figsize=(10, 6))
sns.barplot(data=feature_importance, x='importance', y='feature')
plt.title('Feature Importance', fontsize=14)
plt.xlabel('Importance score')
plt.ylabel('')
plt.yticks(fontsize=12)
plt.tight_layout()
plt.savefig(f"{figure_path}/feature_importance.png", dpi=1200, bbox_inches='tight')
plt.show()

# Diagram of the fitted tree, nodes coloured by class
plt.figure(figsize=(20, 10))
plot_tree(
    dt,
    feature_names=top_biomarkers,
    class_names=class_labels,
    filled=True,
    rounded=True,
    fontsize=12,
)
plt.savefig(f"{figure_path}/decision_tree_viz.png", dpi=1200, bbox_inches='tight')
plt.show()

# Plain-text rules (with per-class weights) saved alongside the other results
tree_text = export_text(dt, feature_names=top_biomarkers, show_weights=True)
rules_file = f"{data_path}/results/decision_tree_rules.txt"
with open(rules_file, 'w') as f:
    f.write(tree_text)

In [None]:
# Function to get decision path for a single sample
def get_decision_path(sample_values):
    """Describe the route one sample takes through the fitted decision tree.

    Relies on two module-level names: ``dt`` (the fitted classifier) and
    ``top_biomarkers`` (feature names, indexed by the tree's feature ids).

    Args:
        sample_values: Mapping of biomarker name to value. It is turned into
            a one-row DataFrame and restricted to ``top_biomarkers``, so a
            missing biomarker raises ``KeyError``.

    Returns:
        A list of strings: one ``"<name> = <value> <op> <threshold>"`` line
        per split node visited (values shown with two decimals, ``<op>``
        being the less-than-or-equal sign U+2264 or ``>``), followed by
        ``"Final prediction: <label>"`` for the leaf.
    """
    X_sample = pd.DataFrame([sample_values])[top_biomarkers]
    node_indicator = dt.decision_path(X_sample)
    leaf_id = dt.apply(X_sample)[0]

    feature = dt.tree_.feature
    threshold = dt.tree_.threshold

    # Row 0 of the sparse indicator lists the node ids this sample visits.
    node_index = node_indicator.indices[node_indicator.indptr[0]:
                                        node_indicator.indptr[1]]

    steps = []
    for node_id in node_index:
        if node_id == leaf_id:
            steps.append(f"Final prediction: {dt.predict(X_sample)[0]}")
            continue

        feature_name = top_biomarkers[feature[node_id]]
        threshold_value = threshold[node_id]
        actual_value = X_sample[feature_name].values[0]
        # Written as an escape so the symbol cannot be mangled again by a
        # wrong file encoding (the literal had decayed into mojibake).
        decision = "\u2264" if actual_value <= threshold_value else ">"

        steps.append(
            f"{feature_name} = {actual_value:.2f} {decision} {threshold_value:.2f}"
        )

    return steps

# Additional visualization: Sample distribution in nodes
def plot_node_sample_distribution(dt, X, y):
    """Plot, for every split node, the samples routed to each of its children.

    Each split node gets its own panel showing histograms of the split
    feature for the samples reaching its left and right child, with the
    split threshold drawn as a dashed red line. Panels are ordered by node
    depth, then node id. The figure is saved as ``node_distributions.png``
    under the module-level ``figure_path`` and then shown.

    Relies on the module-level ``top_biomarkers`` to map the tree's feature
    ids to column names of ``X``.

    Args:
        dt: Fitted decision tree classifier.
        X: Feature DataFrame whose columns include ``top_biomarkers``.
        y: Unused; kept so existing callers keep working.
    """
    tree = dt.tree_
    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right
    feature = tree.feature
    threshold = tree.threshold

    # Depth of every node, found by walking down from the root (node 0).
    node_depths = np.zeros(shape=n_nodes, dtype=np.int64)
    stack = [(0, 0)]  # (node_id, depth)
    while stack:
        node_id, depth = stack.pop()
        node_depths[node_id] = depth

        # Leaves have identical left/right child ids; only splits descend.
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], depth + 1))
            stack.append((children_right[node_id], depth + 1))

    split_nodes = [i for i in range(n_nodes)
                   if children_left[i] != children_right[i]]
    split_nodes.sort(key=lambda i: (node_depths[i], i))

    # Sample-by-node membership matrix; it does not change per node, so it
    # is computed once instead of twice inside every loop iteration.
    reached = dt.decision_path(X).toarray().astype(bool)

    # One panel per split node. Indexing panels by depth alone would draw
    # all nodes of the same depth on top of each other.
    n_cols = max(1, int(np.ceil(np.sqrt(len(split_nodes)))))
    n_rows = int(np.ceil(len(split_nodes) / n_cols))

    plt.figure(figsize=(20, 10))
    for panel, i in enumerate(split_nodes, start=1):
        biomarker = top_biomarkers[feature[i]]
        values_left = X.loc[reached[:, children_left[i]], biomarker]
        values_right = X.loc[reached[:, children_right[i]], biomarker]

        plt.subplot(n_rows, n_cols, panel)
        if len(values_left) > 0:
            plt.hist(values_left, alpha=0.6, label='Left')
        if len(values_right) > 0:
            plt.hist(values_right, alpha=0.6, label='Right')
        plt.axvline(threshold[i], color='r', linestyle='--')
        plt.title(f'Node {i}: {biomarker}')
        plt.legend()

    plt.tight_layout()
    plt.savefig(figure_path + '/node_distributions.png', dpi=1200, bbox_inches='tight')
    plt.show()

# Draw the per-node histograms for the fitted tree
plot_node_sample_distribution(dt, X, y)

# Trace the first row of X through the tree and print each step it takes
example_sample = X.iloc[0].to_dict()
path = get_decision_path(example_sample)
first_true_label = y.iloc[0]
print("\nExample Decision Path (for first sample):")
print(f"True label: {first_true_label}")
for step in path:
    print("- " + step)