# Decision Tree Training

This notebook trains a Decision Tree classifier on the CICIDS2017 intrusion detection dataset.

**Key Features:**
- SMOTE balancing
- Tree complexity analysis
- Feature importance visualization
- Decision rules extraction
- Tree visualization

**Advantages of Decision Trees:**
- Fast training and prediction
- Interpretable results
- Handles non-linear relationships
- No feature scaling required

## 1. Setup and Imports

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from time import time

# Add project root to path
root_dir = os.getcwd().split("AdversarialNIDS")[0] + "AdversarialNIDS"
sys.path.append(root_dir)

from CICIDS2017.dataset import CICIDS2017
from UNSWNB15.dataset import UNSWNB15

# Import shared utilities
from scripts.models.model_utils import (
    check_data_leakage,
    get_tree_feature_importance,
)

# Import model-specific modules
from scripts.models.decision_tree.decision_tree import train_decision_tree
from scripts.analysis.model_analysis import perform_model_analysis
from scripts.logger import LoggerManager

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("Imports successful")

## 2. Initialize Logger

In [None]:
logger = LoggerManager(log_name="dt_notebook").get_logger()
logger.info("Starting Decision Tree training notebook")

## 3. Load and Preprocess Data

In [None]:
# Load dataset
logger.info("Loading CICIDS2017 dataset...")
dataset = (
    CICIDS2017(logger=logger)
    .optimize_memory()
    .encode()
    .scale()
    .subset(size=100000, multi_class=True)
)

## 4. Visualize Class Distribution

In [None]:
# TODO: plot the class distribution of the loaded dataset
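
Until this cell is filled in, the class distribution can be summarized with a few lines of standard-library Python. This is only a sketch: `class_distribution` and `example_labels` are hypothetical names, and the labels are made-up stand-ins for the dataset's attack-type column.

```python
from collections import Counter


def class_distribution(labels):
    """Return per-class counts, per-class shares, and the majority/minority ratio."""
    counts = Counter(labels)
    total = sum(counts.values())
    shares = {cls: n / total for cls, n in counts.items()}
    imbalance_ratio = max(counts.values()) / min(counts.values())
    return counts, shares, imbalance_ratio


# Hypothetical labels standing in for the real attack-type column
example_labels = ["BENIGN"] * 80 + ["DoS"] * 15 + ["PortScan"] * 5

counts, shares, ratio = class_distribution(example_labels)
for cls, n in counts.most_common():
    print(f"{cls:<10} {n:>5} ({shares[cls]:.1%})")
print(f"Imbalance ratio (majority/minority): {ratio:.1f}")
```

A large ratio here is the motivation for the SMOTE balancing applied later in the split step.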

## 5. Data Leakage Check

In [None]:
# TODO: enable once X and y are extracted from the dataset
# diagnostics = check_data_leakage(X, y, logger=logger)
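
One common leakage symptom is a feature that is almost a copy of the label. The sketch below flags such columns by absolute Pearson correlation. It is not the project's `check_data_leakage`; `suspicious_features`, the threshold, and the synthetic data are assumptions for illustration, and the correlation test is only meaningful for binary or ordinal labels.

```python
import numpy as np


def suspicious_features(X, y, feature_names, threshold=0.95):
    """Flag features whose absolute Pearson correlation with the label reaches threshold."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    flagged = []
    for j, name in enumerate(feature_names):
        column = X[:, j]
        if column.std() == 0 or y.std() == 0:
            continue  # correlation is undefined for constant columns
        corr = np.corrcoef(column, y)[0, 1]
        if abs(corr) >= threshold:
            flagged.append((name, float(corr)))
    return flagged


# Synthetic data: one noise feature and one near-copy of the label
rng = np.random.default_rng(0)
y_demo = rng.integers(0, 2, size=200)
X_demo = np.column_stack([
    rng.normal(size=200),                       # unrelated noise
    y_demo + rng.normal(scale=0.01, size=200),  # leaky: label plus tiny noise
])

flagged = suspicious_features(X_demo, y_demo, ["noise", "leaky"])
print(flagged)
```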

## 6. Train/Test Split

In [None]:
# Split data
X_train, X_test, y_train, y_test = dataset.split(test_size=0.2, apply_smote=True)
print(f"Train set shape: {X_train.shape}")
print(f"Test set shape: {X_test.shape}")
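
Oversampling should touch only the training portion, so that the test set keeps the original class ratio. The sketch below illustrates that ordering on synthetic data. It swaps in plain random oversampling as a stand-in for SMOTE, and it makes no claim about how `dataset.split` is implemented; `random_oversample` is a hypothetical helper.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def random_oversample(X, y, random_state=0):
    """Duplicate minority-class rows until every class matches the majority count."""
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    indices = []
    for cls, n in zip(classes, counts):
        cls_idx = np.flatnonzero(y == cls)
        extra = rng.choice(cls_idx, size=target - n, replace=True)
        indices.append(np.concatenate([cls_idx, extra]))
    indices = np.concatenate(indices)
    return X[indices], y[indices]


# Synthetic imbalanced data: 900 samples of class 0, 100 of class 1
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 4))
y_demo = np.array([0] * 900 + [1] * 100)

# Split first (stratified), then balance only the training portion
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=0
)
X_tr_bal, y_tr_bal = random_oversample(X_tr, y_tr)

print("Train before:", np.bincount(y_tr))
print("Train after: ", np.bincount(y_tr_bal))
print("Test:        ", np.bincount(y_te))
```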

## 7. Cross-Validation and Training

In [None]:
# Unpack model and CV scores from train_decision_tree
dt_model, cv_scores = train_decision_tree(
    X_train,
    y_train,
    max_depth=3,
    min_samples_split=10,
    min_samples_leaf=5,
    criterion='gini',
    max_features=None,
    class_weight='balanced',
    cv_test=False,
    cv=3,
    random_state=0,
    logger=logger
)
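# Guard against a missing result; use `is not None` because `!= None` on a NumPy array compares elementwise
if cv_scores is not None: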
    print("\n" + "="*50)
    print("CROSS-VALIDATION RESULTS")
    print("="*50)
    print(f"CV Scores: {cv_scores}")
    print(f"Mean CV Score: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")
    # Plot CV scores
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(cv_scores)+1), cv_scores, marker='o', markersize=10, linewidth=2, color='green')
    plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', 
                label=f'Mean: {cv_scores.mean():.4f}')
    plt.xlabel('Fold', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.title('Decision Tree Cross-Validation Scores', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
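
For reference, the same kind of cross-validated training can be reproduced with scikit-learn directly. This is a sketch on synthetic data with hyperparameters mirroring the call above. It does not show what `train_decision_tree` does internally, which is not visible in this notebook.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic three-class data standing in for the training split
X_demo, y_demo = make_classification(
    n_samples=600, n_features=10, n_informative=5, n_classes=3, random_state=0
)

model = DecisionTreeClassifier(
    max_depth=3,
    min_samples_split=10,
    min_samples_leaf=5,
    criterion="gini",
    class_weight="balanced",
    random_state=0,
)

# Stratified folds keep the class ratio similar in every fold
folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
scores = cross_val_score(model, X_demo, y_demo, cv=folds, scoring="accuracy")
print(f"Fold scores: {scores}")
print(f"Mean: {scores.mean():.4f} (+/- {scores.std():.4f})")

# Final fit on all training data once the CV estimate looks acceptable
model.fit(X_demo, y_demo)
```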

## 8. Analyze Tree Complexity

In [None]:
from scripts.models.decision_tree.analyze_tree import analyze_tree_complexity 
# Analyze tree complexity
complexity = analyze_tree_complexity(dt_model, logger=logger)

# Visualize complexity
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of tree metrics
metrics = ['Total Nodes', 'Leaf Nodes', 'Max Depth', 'Features Used']
values = [complexity['n_nodes'], complexity['n_leaves'], 
          complexity['max_depth'], complexity['n_features_used']]

axes[0].bar(metrics, values, color=['skyblue', 'lightgreen', 'salmon', 'gold'])
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Tree Complexity Metrics', fontsize=14, fontweight='bold')
axes[0].tick_params(axis='x', rotation=45)

# Pie chart of node distribution
internal_nodes = complexity['n_nodes'] - complexity['n_leaves']
axes[1].pie([internal_nodes, complexity['n_leaves']], 
            labels=['Internal Nodes', 'Leaf Nodes'],
            autopct='%1.1f%%', startangle=90,
            colors=['lightcoral', 'lightgreen'])
axes[1].set_title('Node Distribution', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()
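
The four metrics plotted above can be read from a fitted scikit-learn tree. The sketch below shows one way to compute them. `tree_complexity` is a hypothetical helper whose dictionary keys mirror those used in this cell; it is not necessarily how `analyze_tree_complexity` is written.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def tree_complexity(model):
    """Summarize the size of a fitted scikit-learn decision tree."""
    tree = model.tree_
    # Leaf nodes carry a negative feature index, so keep only split nodes
    used = np.unique(tree.feature[tree.feature >= 0])
    return {
        "n_nodes": int(tree.node_count),
        "n_leaves": int(model.get_n_leaves()),
        "max_depth": int(model.get_depth()),
        "n_features_used": int(used.size),
    }


X_demo, y_demo = make_classification(
    n_samples=600, n_features=8, n_informative=4, random_state=0
)
demo_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)
summary = tree_complexity(demo_tree)
print(summary)
```

Because every split has exactly two children, a tree with `n_leaves` leaves always has `2 * n_leaves - 1` nodes, which is a quick sanity check on the pie chart.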

## 9. Evaluate on Test Set and Confusion Matrix

In [None]:
# Evaluate model
cm, cr = perform_model_analysis(
    model=dt_model,
    X_test=X_test,
    y_test=y_test,
    logger=logger,
    model_name="DecisionTree",
    dir=os.getcwd(),
    plot=True
)
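
The names `cm` and `cr` suggest a confusion matrix and a classification report. A minimal scikit-learn equivalent on synthetic data might look like the sketch below. The assumption that `perform_model_analysis` returns these two objects comes only from the variable names.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(
    n_samples=600, n_features=10, n_informative=5, n_classes=3, random_state=0
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=0
)

model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_tr, y_tr)
y_pred = model.predict(X_te)

# Rows are true classes, columns are predicted classes
cm_demo = confusion_matrix(y_te, y_pred)
# Per-class precision, recall, and F1; zero_division=0 avoids warnings for unpredicted classes
report = classification_report(y_te, y_pred, digits=4, zero_division=0)
print(cm_demo)
print(report)
```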

## 10. Feature Importance Analysis

In [None]:
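# Get feature importance
# Names must line up with the model's input columns, so drop the label column ('Attack Type') if present
importance_feature_names = list(dataset.data.columns.drop('Attack Type', errors='ignore'))
top_features = get_tree_feature_importance(
    dt_model,
    feature_names=importance_feature_names,
    top_n=15,
    logger=logger
)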

# Plot feature importance
features, importances = zip(*top_features)

plt.figure(figsize=(12, 8))
colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(features)))
plt.barh(range(len(features)), importances, color=colors)
plt.yticks(range(len(features)), features)
plt.xlabel('Importance', fontsize=12)
plt.title('Top 15 Feature Importances', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()
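
A fitted scikit-learn tree exposes one importance value per training column in `feature_importances_`, so the name list must have the same length and order. The sketch below ranks them on synthetic data. `top_importances` is a hypothetical helper, not the project's `get_tree_feature_importance`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def top_importances(model, feature_names, top_n=15):
    """Return (name, importance) pairs sorted from most to least important."""
    importances = model.feature_importances_
    if len(feature_names) != len(importances):
        raise ValueError("feature_names must match the columns used for training")
    order = np.argsort(importances)[::-1][:top_n]
    return [(feature_names[i], float(importances[i])) for i in order]


X_demo, y_demo = make_classification(
    n_samples=600, n_features=10, n_informative=5, n_classes=3, random_state=0
)
names = [f"feature_{i}" for i in range(X_demo.shape[1])]
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)

top = top_importances(model, names, top_n=5)
for name, score in top:
    print(f"{name:<12} {score:.4f}")
```

With a depth-3 tree, at most seven splits exist, so most features will have an importance of exactly zero.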

## 11. Extract Decision Rules

In [None]:
from scripts.models.decision_tree.tree_rules import get_tree_rules
# Use the correct feature names: drop the label column (e.g., 'Attack Type') if present
feature_names = list(dataset.data.columns.drop('Attack Type', errors='ignore'))
rules = get_tree_rules(dt_model, feature_names, max_depth=3)

print("\n" + "="*70)
print("DECISION RULES (Top 3 Levels)")
print("="*70)
print(rules)
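# Rules are truncated at depth 3; only flag a continuation if the fitted tree is deeper
# (assumes dt_model is a fitted scikit-learn tree exposing get_depth())
if dt_model.get_depth() > 3:
    print("\n... (tree continues deeper)")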
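
scikit-learn ships a comparable text export, `sklearn.tree.export_text`, which can serve as a cross-check for the extracted rules. The sketch below runs it on synthetic data. Whether `get_tree_rules` wraps this function is not known from this notebook.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

X_demo, y_demo = make_classification(
    n_samples=600, n_features=10, n_informative=5, n_classes=3, random_state=0
)
names = [f"feature_{i}" for i in range(X_demo.shape[1])]
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)

# Each indented "|---" line is one condition; leaves end in "class: <label>"
rules_text = export_text(model, feature_names=names, max_depth=3)
print(rules_text)
```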

## Tips for Improving Decision Tree Performance

### To Reduce Overfitting:
1. **Limit max_depth** - Prevents very deep, overfit trees
2. **Increase min_samples_split** - Requires more samples before splitting
3. **Increase min_samples_leaf** - Requires more samples at leaf nodes
4. **Use max_features** - Adds randomness (closer to Random Forest)
5. **Prune the tree** - Post-pruning using cost complexity pruning
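
The post-pruning option in the list above can be sketched with scikit-learn's cost complexity pruning. The example uses synthetic data and a simple hold-out set to choose `ccp_alpha`; the split sizes and selection rule are assumptions, not a tuned recipe.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(
    n_samples=400, n_features=10, n_informative=5, n_classes=3, random_state=0
)
X_tr, X_val, y_tr, y_val = train_test_split(
    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=0
)

# Grow an unpruned tree, then compute the candidate alphas along its pruning path
full_tree = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
path = full_tree.cost_complexity_pruning_path(X_tr, y_tr)
alphas = np.unique(np.clip(path.ccp_alphas, 0.0, None))  # guard against tiny negatives

# Keep the alpha with the best validation accuracy (ties go to the smaller alpha)
best_alpha, best_score, pruned_tree = 0.0, -1.0, None
for alpha in alphas:
    candidate = DecisionTreeClassifier(random_state=0, ccp_alpha=float(alpha))
    candidate.fit(X_tr, y_tr)
    score = candidate.score(X_val, y_val)
    if score > best_score:
        best_alpha, best_score, pruned_tree = float(alpha), score, candidate

print(f"Leaves before pruning: {full_tree.get_n_leaves()}")
print(f"Leaves after pruning:  {pruned_tree.get_n_leaves()} (ccp_alpha={best_alpha:.5f})")
print(f"Validation accuracy:   {best_score:.4f}")
```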

### To Improve Performance:
1. **Use Random Forest** - Ensemble of trees usually performs better
2. **Try criterion='entropy'** - May work better than 'gini' for some datasets
3. **Feature engineering** - Create more informative features
4. **Handle class imbalance** - Use class_weight='balanced' or SMOTE
5. **Grid search** - Find optimal hyperparameters
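
The grid search item above could be sketched as follows. The grid values, the macro-F1 scoring choice, and the synthetic data are assumptions for illustration; a real search on CICIDS2017 would need a grid chosen for that data.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(
    n_samples=600, n_features=10, n_informative=5, n_classes=3, random_state=0
)

# Small illustrative grid over depth, leaf size, and split criterion
param_grid = {
    "max_depth": [3, 5, 10],
    "min_samples_leaf": [1, 5],
    "criterion": ["gini", "entropy"],
}

search = GridSearchCV(
    DecisionTreeClassifier(class_weight="balanced", random_state=0),
    param_grid,
    cv=3,
    scoring="f1_macro",  # macro-F1 weights every class equally
)
search.fit(X_demo, y_demo)

print("Best parameters:", search.best_params_)
print(f"Best CV macro-F1: {search.best_score_:.4f}")
```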

### Interpretability vs Performance:
- **Shallow trees** (depth 5-10): More interpretable, but may underfit
- **Deep trees** (depth 20+): Often fit the training data better, but are less interpretable and more prone to overfitting
- **Random Forest**: Usually the strongest performance of the three, but the hardest to interpret
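
The depth trade-off can be checked empirically by comparing train and test accuracy across a few depths. The sketch below does this on synthetic data, so the numbers say nothing about CICIDS2017; the depths tried are arbitrary choices.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(
    n_samples=800, n_features=10, n_informative=5, n_classes=3, random_state=0
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.25, stratify=y_demo, random_state=0
)

# Maps depth limit -> (train accuracy, test accuracy, number of leaves)
depth_results = {}
for depth in (3, 5, 10, None):  # None lets the tree grow until leaves are pure
    model = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    depth_results[depth] = (
        model.score(X_tr, y_tr),
        model.score(X_te, y_te),
        model.get_n_leaves(),
    )
    train_acc, test_acc, n_leaves = depth_results[depth]
    print(f"max_depth={str(depth):>4}  train={train_acc:.3f}  test={test_acc:.3f}  leaves={n_leaves}")
```

A widening gap between the train and test columns as depth grows is the overfitting signal described above.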