# Decision Tree Training on CICIDS2017 Dataset

This notebook trains a Decision Tree classifier on the CICIDS2017 intrusion detection dataset.

**Key Features:**
- SMOTE oversampling of the training split only (the test set is left untouched)
- Tree complexity analysis
- Feature importance visualization
- Decision rules extraction
- Hyperparameter configuration, with tuning tips at the end
- Tree visualization

**Advantages of Decision Trees:**
- Fast training and prediction
- Interpretable results
- Handles non-linear relationships
- No feature scaling required
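
The last point can be checked on toy data. A tree compares one feature against a threshold at each split, so multiplying a column by a positive constant moves the thresholds but leaves the chosen splits and the predictions unchanged. The sketch below uses synthetic integer-valued features (an assumption, chosen so the arithmetic stays exact) rather than CICIDS2017 data.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
# Synthetic, integer-valued features keep the scaled arithmetic exact
X_train = rng.integers(0, 100, size=(400, 4)).astype(float)
X_new = rng.integers(0, 100, size=(100, 4)).astype(float)
y_train = (X_train[:, 0] + 0.5 * X_train[:, 1] > 75).astype(int)

# Rescale each column by a different positive factor (like bytes vs. kilobytes)
scales = np.array([1.0, 1e3, 10.0, 50.0])

tree_raw = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)
tree_scaled = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train * scales, y_train)

# Same feature chosen at every node, and same predictions on unseen rows
same_splits = np.array_equal(tree_raw.tree_.feature, tree_scaled.tree_.feature)
same_predictions = np.array_equal(tree_raw.predict(X_new), tree_scaled.predict(X_new * scales))
print(same_splits, same_predictions)
```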

## 1. Setup and Imports

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from time import time

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))

from CICIDS2017.preprocessing.dataset import CICIDS2017
# Import shared utilities
from scripts.models.model_utils import (
    prepare_data,
    evaluate_model,
    check_data_leakage,
    get_feature_importance,
    balance_classes_info,
    remove_rare_classes,
    print_performance_summary,
    remove_low_variance_features
)

# Import model-specific modules
from scripts.models.decision_tree.decision_tree import train_decision_tree

from scripts.logger import LoggerManager

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.tree import plot_tree

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports successful")

## 2. Initialize Logger

In [None]:
logger = LoggerManager(log_name="dt_notebook").get_logger()
logger.info("Starting Decision Tree training notebook")

## 3. Load and Preprocess Data

In [None]:
# Load dataset
logger.info("Loading CICIDS2017 dataset...")
dataset = CICIDS2017(logger=logger)
dataset.encode().optimize_memory()
data = dataset.data

print(f"Dataset shape: {data.shape}")
data.head()

## 4. Sample Data

In [None]:
# Decision trees are fast, so we can use a larger sample
SAMPLE_SIZE = 200000

logger.info(f"Sampling {SAMPLE_SIZE} rows from dataset...")
data_sample = data.sample(n=min(SAMPLE_SIZE, len(data)), random_state=0)

print(f"Sampled data shape: {data_sample.shape}")
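
`data.sample(n=...)` draws rows uniformly, so a class with only a few dozen rows in the full dataset can shrink to one or two rows, or vanish, in a 200,000-row sample. One alternative is sketched below with a hypothetical `stratified_sample` helper (not part of the project's utilities): every class is sampled at the same rate, but each keeps at least `min_per_class` rows (an assumed floor), so the result can be slightly larger than `n_total`.

```python
import pandas as pd

def stratified_sample(df, label_col, n_total, min_per_class=50, random_state=0):
    """Hypothetical helper: sample each class at the same rate, but keep at least
    `min_per_class` rows (or the whole class, if it is smaller)."""
    frac = min(1.0, n_total / len(df))
    parts = []
    for _, group in df.groupby(label_col, sort=False):
        n = max(min_per_class, int(round(frac * len(group))))
        parts.append(group.sample(n=min(n, len(group)), random_state=random_state))
    # Shuffle so classes are not stored in contiguous blocks
    return pd.concat(parts).sample(frac=1.0, random_state=random_state)

# Toy frame: 1000 benign rows and a rare class with 10 rows
toy = pd.DataFrame({
    "Attack Type": ["BENIGN"] * 1000 + ["Heartbleed"] * 10,
    "Flow Duration": range(1010),
})
toy_sample = stratified_sample(toy, "Attack Type", n_total=101, min_per_class=5)
print(toy_sample["Attack Type"].value_counts())
```

In the notebook this would replace the uniform draw with something like `stratified_sample(data, 'Attack Type', SAMPLE_SIZE)`.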

## 5. Prepare Features and Labels

In [None]:
# Split features and labels
X = data_sample.drop('Attack Type', axis=1)
y = data_sample['Attack Type']

# Remove known leakage features
leakage_features = ['Attack Number']
existing_leakage = [f for f in leakage_features if f in X.columns]

if existing_leakage:
    logger.warning(f"🚨 REMOVING LEAKAGE FEATURES: {existing_leakage}")
    X = X.drop(columns=existing_leakage)

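# Convert to numeric (unparseable entries become NaN)
X = X.apply(pd.to_numeric, errors='coerce')

# Treat infinite values (rate features such as Flow Bytes/s can contain them) as missing,
# since scikit-learn estimators reject inf, then fill all missing values with 0
X = X.replace([np.inf, -np.inf], np.nan)
n_missing = int(X.isnull().sum().sum())
if n_missing > 0:
    logger.info(f"Filling {n_missing} missing or infinite values with 0")
    X = X.fillna(0)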

# Remove low variance features
X, removed_features = remove_low_variance_features(X, threshold=0.01, logger=logger)

print(f"Feature matrix shape: {X.shape}")
print(f"\nClass distribution:")
print(y.value_counts())
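
`remove_low_variance_features` is a project helper whose implementation is not shown here. Assuming it drops columns whose variance does not exceed `threshold`, an equivalent pandas sketch follows; the name `drop_low_variance` and the use of population variance (`ddof=0`) are assumptions. Because variance depends on units, a threshold of 0.01 treats a 0/1 flag column and a byte count very differently.

```python
import pandas as pd

def drop_low_variance(X: pd.DataFrame, threshold: float = 0.01):
    """Hypothetical stand-in for remove_low_variance_features.
    Returns (X_reduced, removed_columns); keeps columns with variance > threshold."""
    variances = X.var(ddof=0)
    removed = variances[variances <= threshold].index.tolist()
    return X.drop(columns=removed), removed

X_demo = pd.DataFrame({
    "constant_flag": [0, 0, 0, 0],   # variance 0 -> removed
    "rare_flag": [0, 0, 0, 1],       # variance 0.1875 -> kept
    "packet_count": [10, 250, 3, 42],
})
X_reduced, removed = drop_low_variance(X_demo, threshold=0.01)
print(removed)
```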

## 6. Visualize Class Distribution

In [None]:
# Plot class distribution
plt.figure(figsize=(12, 6))
y.value_counts().plot(kind='bar', color='skyblue', edgecolor='black')
plt.title('Class Distribution (Before SMOTE)', fontsize=14, fontweight='bold')
plt.xlabel('Attack Type', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 7. Data Leakage Check

In [None]:
diagnostics = check_data_leakage(X, y, logger=logger)
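
What `check_data_leakage` inspects is not shown. A common complementary heuristic, sketched below with a hypothetical `flag_suspicious_features` helper, scores each feature on its own with a shallow tree: a single column that separates the classes almost perfectly, the way an identifier like `Attack Number` would, deserves a closer look. Balanced accuracy is used so the majority BENIGN class cannot inflate the score; the depth, threshold and fold count are assumptions.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

def flag_suspicious_features(X, y, threshold=0.99, cv=3):
    """Hypothetical helper: return {column: score} for columns that alone
    reach `threshold` cross-validated balanced accuracy."""
    flagged = {}
    for col in X.columns:
        clf = DecisionTreeClassifier(max_depth=2, random_state=0)
        score = cross_val_score(clf, X[[col]], y, cv=cv, scoring="balanced_accuracy").mean()
        if score >= threshold:
            flagged[col] = round(float(score), 4)
    return flagged

rng = np.random.default_rng(0)
y_demo = pd.Series(rng.integers(0, 2, size=300))
X_demo = pd.DataFrame({
    "leaky_id": y_demo * 1000 + rng.integers(0, 10, size=300),  # encodes the label
    "noise": rng.normal(size=300),
})
flagged = flag_suspicious_features(X_demo, y_demo)
print(flagged)
```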

## 8. Train/Test Split

In [None]:
# Remove classes with fewer than 2 samples
class_counts = y.value_counts()
valid_classes = class_counts[class_counts >= 2].index
X = X[y.isin(valid_classes)]
y = y[y.isin(valid_classes)]

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.25,
    random_state=0,
    stratify=y
)

print(f"Train set shape: {X_train.shape}")
print(f"Test set shape: {X_test.shape}")

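## 9. Apply SMOTE to the Training Set

In [None]:
# Oversample the training split only; the test set keeps its real class distribution.
# Note: SMOTE runs here, before cross-validation, so the CV scores in Section 10 are
# computed on partly synthetic folds and may be optimistic.
from imblearn.over_sampling import SMOTE

smote = SMOTE(random_state=0)
X_train_res, y_train_res = smote.fit_resample(X_train, y_train)

print(f"Resampled training set shape: {X_train_res.shape}")
print(y_train_res.value_counts())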
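
SMOTE interpolates between a minority sample and its `k_neighbors` nearest same-class neighbours (5 by default in imbalanced-learn), so every class being oversampled needs at least `k_neighbors + 1` rows in the training split. The filter in Section 8 only guarantees 2 rows per class before the 75/25 split, so rare classes can still make `fit_resample` fail. A stdlib sketch of a hypothetical `safe_smote_k` helper that picks the largest usable value:

```python
from collections import Counter

def safe_smote_k(labels, default_k=5):
    """Hypothetical helper: largest k_neighbors (<= default_k) that every
    class SMOTE would oversample can support."""
    counts = Counter(labels)
    majority = max(counts.values())
    minority_sizes = [c for c in counts.values() if c < majority]
    if not minority_sizes:
        return default_k  # already balanced: nothing to oversample
    smallest = min(minority_sizes)
    if smallest < 2:
        raise ValueError("a class with a single sample cannot be oversampled by SMOTE")
    return min(default_k, smallest - 1)

y_example = ["BENIGN"] * 500 + ["DoS"] * 40 + ["Heartbleed"] * 4
print(safe_smote_k(y_example))  # 3
```

The result would be passed as `SMOTE(k_neighbors=safe_smote_k(y_train), random_state=0)`; dropping or merging classes that are too small is the other option.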

## 10. Cross-Validation

In [None]:
# Unpack model and CV scores from train_decision_tree
dt_model, cv_scores = train_decision_tree(
    X_train_res,
    y_train_res,
    max_depth=3,
    min_samples_split=10,
    min_samples_leaf=5,
    criterion='gini',
    max_features=None,
    class_weight='balanced',
    random_state=0,
    logger=logger
)

print("\n" + "="*50)
print("CROSS-VALIDATION RESULTS")
print("="*50)
print(f"CV Scores: {cv_scores}")
print(f"Mean CV Score: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")

# Plot CV scores
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(cv_scores)+1), cv_scores, marker='o', markersize=10, linewidth=2, color='green')
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', 
            label=f'Mean: {cv_scores.mean():.4f}')
plt.xlabel('Fold', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Decision Tree Cross-Validation Scores', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
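
Because SMOTE runs on the whole training split before `train_decision_tree` cross-validates, a synthetic row in a training fold may be interpolated from a real row that sits in the validation fold, so the model is scored on points it has nearly seen. The usual remedy is to resample inside each training fold only (imbalanced-learn's `Pipeline` does this when SMOTE is one of its steps). The sketch below shows the principle with scikit-learn and NumPy alone; plain random oversampling stands in for SMOTE to keep it dependency-free, and both helper names are hypothetical.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

def random_oversample(X, y, rng):
    """Hypothetical helper: duplicate rows at random until every class matches the largest."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    idx = [np.flatnonzero(y == c) for c in classes]
    extra = [rng.choice(i, size=target - len(i), replace=True) for i in idx]
    keep = np.concatenate(idx + extra)
    return X[keep], y[keep]

def cv_with_fold_resampling(X, y, n_splits=5, seed=0):
    """Hypothetical helper: stratified CV that resamples the training folds only."""
    rng = np.random.default_rng(seed)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    scores = []
    for train_idx, val_idx in skf.split(X, y):
        # The validation fold keeps its real class distribution
        X_res, y_res = random_oversample(X[train_idx], y[train_idx], rng)
        clf = DecisionTreeClassifier(max_depth=3, random_state=seed).fit(X_res, y_res)
        scores.append(clf.score(X[val_idx], y[val_idx]))
    return np.array(scores)

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(600, 5))
y_demo = np.where(X_demo[:, 0] > 1.2, 1, 0)  # roughly 11% positives
fold_scores = cv_with_fold_resampling(X_demo, y_demo)
print(fold_scores.round(3), fold_scores.mean().round(3))
```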

## 11. Train Final Model

In [None]:
# Train on full training set
logger.info("Training final model on full training set...")
start_time = time()

dt_model.fit(X_train_res, y_train_res)

training_time = time() - start_time

print(f"✓ Model training completed in {training_time:.2f} seconds")

## 12. Analyze Tree Complexity

In [None]:
from scripts.models.decision_tree.analyze_tree import analyze_tree_complexity 
# Analyze tree complexity
complexity = analyze_tree_complexity(dt_model, logger=logger)

# Visualize complexity
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of tree metrics
metrics = ['Total Nodes', 'Leaf Nodes', 'Max Depth', 'Features Used']
values = [complexity['n_nodes'], complexity['n_leaves'], 
          complexity['max_depth'], complexity['n_features_used']]

axes[0].bar(metrics, values, color=['skyblue', 'lightgreen', 'salmon', 'gold'])
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Tree Complexity Metrics', fontsize=14, fontweight='bold')
axes[0].tick_params(axis='x', rotation=45)

# Pie chart of node distribution
internal_nodes = complexity['n_nodes'] - complexity['n_leaves']
axes[1].pie([internal_nodes, complexity['n_leaves']], 
            labels=['Internal Nodes', 'Leaf Nodes'],
            autopct='%1.1f%%', startangle=90,
            colors=['lightcoral', 'lightgreen'])
axes[1].set_title('Node Distribution', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()
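
`analyze_tree_complexity` comes from the project and its internals are not shown. Assuming it reads the fitted tree's structure, the same four numbers are available directly from a scikit-learn `DecisionTreeClassifier` (for a `Pipeline`, take `named_steps['dt']` first). The helper name `tree_complexity` is hypothetical; `tree_.feature` holds a negative placeholder at leaves, which is why leaves are filtered out before counting features.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def tree_complexity(tree: DecisionTreeClassifier) -> dict:
    """Hypothetical stand-in for analyze_tree_complexity."""
    internal = tree.tree_.feature >= 0  # leaves store a negative placeholder
    return {
        "n_nodes": int(tree.tree_.node_count),
        "n_leaves": int(tree.get_n_leaves()),
        "max_depth": int(tree.get_depth()),
        "n_features_used": int(np.unique(tree.tree_.feature[internal]).size),
    }

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(500, 6))
y_demo = (X_demo[:, 0] * X_demo[:, 1] > 0).astype(int)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)
stats = tree_complexity(clf)
print(stats)
```

Every internal node of a scikit-learn tree has exactly two children, so `n_nodes == 2 * n_leaves - 1`; the internal-vs-leaf pie chart above will therefore always sit close to 50/50.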

## 13. Evaluate on Test Set

In [None]:
# Evaluate model
results = evaluate_model(dt_model, X_test, y_test, logger=logger)

print("\n" + "="*50)
print("TEST SET RESULTS")
print("="*50)
print(f"Test Accuracy: {results['accuracy']:.4f}")
print(f"\nClassification Report:")
print(results['report'])
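
CICIDS2017 is dominated by BENIGN traffic, so overall accuracy can look excellent while a rare attack class is missed entirely. Reporting balanced accuracy or macro-averaged F1 alongside it makes that visible. The toy labels below are illustrative, not results from this notebook.

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# Toy labels: 9 benign flows, 1 attack, and a model that always predicts BENIGN
y_true = ["BENIGN"] * 9 + ["DoS"]
y_pred = ["BENIGN"] * 10

acc = accuracy_score(y_true, y_pred)               # 0.9
bal_acc = balanced_accuracy_score(y_true, y_pred)  # mean of per-class recall (1.0, 0.0)
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(f"accuracy={acc:.2f}  balanced={bal_acc:.2f}  macro-F1={macro_f1:.2f}")
```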

## 14. Confusion Matrix

In [None]:
# Plot confusion matrix on an explicit axes so the requested figure size is used
# (ConfusionMatrixDisplay.plot() otherwise opens a new figure of its own)
fig, ax = plt.subplots(figsize=(12, 10))
disp = ConfusionMatrixDisplay(confusion_matrix=results['confusion_matrix'],
                              display_labels=dt_model.classes_)
disp.plot(ax=ax, cmap='Greens', xticks_rotation=45)
ax.set_title('Decision Tree Confusion Matrix', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
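
With raw counts, the BENIGN row dominates the colour scale and small attack classes are hard to read. Dividing each row by its total turns the matrix into per-class recall (each row sums to 1); scikit-learn's `confusion_matrix` offers the same through `normalize='true'`. A NumPy sketch on an illustrative 3-class matrix (the counts are made up):

```python
import numpy as np

# Illustrative counts: rows = true class, columns = predicted class
cm = np.array([
    [9500, 40, 10],  # BENIGN
    [30, 270, 0],    # DoS
    [5, 0, 15],      # Bot
])

# Row-normalise; rows with no samples stay at 0 instead of dividing by zero
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros_like(cm, dtype=float), where=row_totals > 0)
print(np.round(cm_norm, 3))
```

The normalised array can be passed as `confusion_matrix=` to `ConfusionMatrixDisplay`, with `values_format='.2f'` in `disp.plot(...)`.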

## 15. Feature Importance Analysis

In [None]:
# Get feature importance
top_features = get_feature_importance(
    dt_model,
    feature_names=list(X.columns),
    top_n=15,
    logger=logger
)

# Plot feature importance
features, importances = zip(*top_features)

plt.figure(figsize=(12, 8))
colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(features)))
plt.barh(range(len(features)), importances, color=colors)
plt.yticks(range(len(features)), features)
plt.xlabel('Importance', fontsize=12)
plt.title('Top 15 Feature Importances', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()
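
A tree's `feature_importances_` is impurity-based (mean decrease in impurity): each feature's share of the total impurity reduction over the splits that use it, normalised to sum to 1, so features the tree never splits on score exactly 0. With `max_depth=3` only a handful of features can be non-zero. Permutation importance on held-out data is a common cross-check, since it measures the drop in score when one column is shuffled. The synthetic data below are an assumption for illustration.

```python
import numpy as np
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(400, 3))
y_demo = (X_demo[:, 0] > 0).astype(int)  # only column 0 carries signal

X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.25, random_state=0)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_tr, y_tr)

mdi = clf.feature_importances_  # impurity-based, sums to 1
perm = permutation_importance(clf, X_te, y_te, n_repeats=10, random_state=0).importances_mean
print("MDI:", np.round(mdi, 3))
print("Permutation:", np.round(perm, 3))
```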

## 16. Extract Decision Rules

In [None]:
from scripts.models.decision_tree.tree_rules import get_tree_rules

rules = get_tree_rules(dt_model, feature_names=list(X.columns), max_depth=3)

print("\n" + "="*70)
print("DECISION RULES (Top 3 Levels)")
print("="*70)
print(rules)
if complexity['max_depth'] > 3:
    print("\n... (tree continues deeper)")
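
`get_tree_rules` is a project helper and its output format is not shown. If a plain-text dump is all that is needed, scikit-learn's `sklearn.tree.export_text` produces one directly from a fitted `DecisionTreeClassifier` (for the pipeline in this notebook, that would be `dt_model.named_steps['dt']`). The data and feature names below are illustrative only.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 3))
y_demo = np.where(X_demo[:, 0] > 0.5, "DoS", "BENIGN")

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)
rules = export_text(
    clf,
    feature_names=["Flow Duration", "Fwd Packets", "Bwd Packets"],  # illustrative names
    max_depth=3,        # truncate the printout; deeper branches are elided
    show_weights=True,  # print class counts at each leaf
)
print(rules)
```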

## 17. Visualize Tree (Top Levels)

We'll visualize only the top levels of the tree for clarity.

In [None]:
# Visualize the tree (top 3 levels only).
# Extract the fitted tree without overwriting dt_model, which later cells still use as a pipeline.
tree_clf = dt_model.named_steps['dt']

plt.figure(figsize=(20, 10))
plot_tree(
    tree_clf,
    feature_names=list(X.columns),
    class_names=[str(c) for c in tree_clf.classes_],
    filled=True,
    rounded=True,
    max_depth=3,  # Only show top 3 levels
    fontsize=10
)
plt.title('Decision Tree Visualization (Top 3 Levels)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\n📝 Note: Full tree has {complexity['max_depth']} levels.")
print("   Only showing top 3 levels for clarity.")

## 18. Performance Summary

In [None]:
print("\n" + "="*70)
print("FINAL PERFORMANCE SUMMARY")
print("="*70)
print(f"Mean CV Score: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")
print(f"Test Accuracy: {results['accuracy']:.4f}")
print(f"\nModel Configuration:")
print(f"  - Max depth: {dt_model.named_steps['dt'].max_depth}")
print(f"  - Criterion: {dt_model.named_steps['dt'].criterion}")
print(f"  - Min samples split: {dt_model.named_steps['dt'].min_samples_split}")
print(f"  - Min samples leaf: {dt_model.named_steps['dt'].min_samples_leaf}")
print(f"  - SMOTE: Enabled")
print(f"\nTree Complexity:")
print(f"  - Total nodes: {complexity['n_nodes']}")
print(f"  - Leaf nodes: {complexity['n_leaves']}")
print(f"  - Actual depth: {complexity['max_depth']}")
print(f"  - Features used: {complexity['n_features_used']}/{X.shape[1]}")
print(f"\nTiming:")
print(f"  - Training time: {training_time:.2f}s")

# Performance indicators
if cv_scores.mean() > 0.99:
    print("\n⚠️  WARNING: CV score > 0.99 may indicate data leakage or overfitting!")
elif cv_scores.mean() >= 0.95:
    print("\n✓ Excellent performance achieved (CV score ≥ 0.95)")
elif cv_scores.mean() >= 0.90:
    print("\n✓ Good performance achieved (CV score ≥ 0.90)")
else:
    print("\n⚠️  Performance below 0.90")
    print("   Consider:")
    print("   - Increasing max_depth")
    print("   - Decreasing min_samples_split/min_samples_leaf")
    print("   - Trying criterion='entropy' instead of 'gini'")
    print("   - Using Random Forest instead (ensemble of trees)")

# Check for overfitting
if abs(cv_scores.mean() - results['accuracy']) > 0.05:
    print("\n⚠️  Warning: Large gap between CV and test accuracy")
    print(f"   CV: {cv_scores.mean():.4f}, Test: {results['accuracy']:.4f}")
    print("   This may indicate overfitting. Consider:")
    print("   - Reducing max_depth")
    print("   - Increasing min_samples_split/min_samples_leaf")
    print("   - Using max_features to add randomness")

logger.info("Notebook execution completed successfully!")

## Tips for Improving Decision Tree Performance

### To Reduce Overfitting:
1. **Limit max_depth** - Prevents very deep, overfit trees
2. **Increase min_samples_split** - Requires more samples before splitting
3. **Increase min_samples_leaf** - Requires more samples at leaf nodes
4. **Use max_features** - Adds randomness (closer to Random Forest)
5. **Prune the tree** - Post-pruning using cost complexity pruning
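
Point 5 refers to minimal cost-complexity pruning, exposed in scikit-learn through the `ccp_alpha` parameter. `cost_complexity_pruning_path` lists the effective alphas at which subtrees are pruned away; fitting one tree per alpha and scoring each on validation data is one way to pick a value. The synthetic data and the simple hold-out split are assumptions.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(600, 5))
noise = rng.random(600) < 0.1
y_demo = ((X_demo[:, 0] + X_demo[:, 1] > 0) ^ noise).astype(int)  # 10% label noise

X_tr, X_val, y_tr, y_val = train_test_split(X_demo, y_demo, test_size=0.3, random_state=0)

# Effective alphas of the fully grown tree, from no pruning up to a single root node
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_tr, y_tr)
ccp_alphas = path.ccp_alphas

trees = [DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_tr, y_tr) for a in ccp_alphas]
val_scores = [t.score(X_val, y_val) for t in trees]
leaves = [t.get_n_leaves() for t in trees]

best = int(np.argmax(val_scores))
print(f"best ccp_alpha={ccp_alphas[best]:.5f}, leaves={leaves[best]}, val acc={val_scores[best]:.3f}")
```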

### To Improve Performance:
1. **Use Random Forest** - Ensemble of trees usually performs better
2. **Try criterion='entropy'** - May work better than 'gini' for some datasets
3. **Feature engineering** - Create more informative features
4. **Handle class imbalance** - Use class_weight='balanced' or SMOTE
5. **Grid search** - Find optimal hyperparameters
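
For point 5, `GridSearchCV` cross-validates every combination in a parameter grid and refits the best one. The grid below is an assumption chosen to mirror parameters used in Section 10 and kept small so it runs quickly; scoring with `f1_macro` rather than accuracy keeps rare attack classes in view. The data are synthetic.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(400, 4))
y_demo = (X_demo[:, 0] - X_demo[:, 2] > 0.3).astype(int)

param_grid = {
    "max_depth": [3, 5, 10],
    "min_samples_leaf": [1, 5],
    "criterion": ["gini", "entropy"],
}
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid,
    scoring="f1_macro",
    cv=5,
)
search.fit(X_demo, y_demo)
print(search.best_params_, round(search.best_score_, 3))
```

If the search wraps the notebook's pipeline instead of a bare tree, parameter names take the step prefix, for example `dt__max_depth`.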

### Interpretability vs Performance:
- **Shallow trees** (depth 5-10): More interpretable, may underfit
- **Deep trees** (depth 20+): Fit more complex boundaries, less interpretable, prone to overfitting
- **Random Forest**: Usually stronger performance than a single tree, harder to interpret