# Improve Models and Prevent Overfitting

This notebook provides techniques to improve model performance and detect/prevent overfitting.

## ⚠️ Before Running

**IMPORTANT**: Update the dataset path in the **Load and Prepare Data** cell (Section 2) before running this notebook.

1. Open the Load and Prepare Data cell (Section 2)
2. Update `dataset_path` to point to your dataset file
3. Example: `dataset_path = r'd:\drone_firmware_full_mitre_dataset.csv'`

The notebook will automatically:
- Detect MITRE dataset format and convert if needed
- Find the target column (`clean_label`, `classification`, etc.)
- Handle data preprocessing and scaling
- Provide helpful error messages if something goes wrong


## 1. Import Libraries


In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns


## 2. Load and Prepare Data

**⚠️ IMPORTANT**: Update the dataset path below to point to your actual dataset file.

**Example paths:**
- MITRE dataset: `r'd:\drone_firmware_full_mitre_dataset.csv'`
- Local dataset: `r'path\to\your\dataset.csv'`
- Relative path: `'../data/your_dataset.csv'`


In [5]:
# ⚠️ UPDATE THIS PATH TO YOUR DATASET ⚠️
# Example: dataset_path = r'd:\drone_firmware_full_mitre_dataset.csv'
# Or use a relative path: dataset_path = '../data/your_dataset.csv'
dataset_path = r'd:\drone_firmware_full_mitre_dataset.csv'  # Change this!

try:
    # Load your dataset
    print(f"Loading dataset from: {dataset_path}")
    df = pd.read_csv(dataset_path)
    print(f"✓ Dataset loaded successfully: {len(df)} rows, {len(df.columns)} columns")
    print(f"  Columns: {list(df.columns)[:10]}...")  # Show first 10 columns
    
    # Check if this is a MITRE dataset and convert if needed
    try:
        from mitre_dataset_converter import detect_mitre_dataset, convert_mitre_dataset_to_system_format
        if detect_mitre_dataset(df):
            print("✓ Detected MITRE dataset format, converting...")
            df = convert_mitre_dataset_to_system_format(df)
            print(f"✓ Converted MITRE dataset: {len(df)} rows")
    except ImportError:
        print("⚠ MITRE converter not available, using dataset as-is")
    except Exception as e:
        print(f"⚠ MITRE conversion warning: {e}")
    
    # Determine target column
    target_col = None
    possible_targets = ['clean_label', 'is_tampered', 'classification', 'label', 'target']
    for col in possible_targets:
        if col in df.columns:
            target_col = col
            break
    
    if target_col is None:
        raise ValueError(
            f"No target column found! Looking for: {possible_targets}\n"
            f"Available columns: {list(df.columns)}"
        )
    
    print(f"✓ Using target column: {target_col}")
    
    # Separate features and target
    drop_cols = [target_col]
    if 'firmware_id' in df.columns:
        drop_cols.append('firmware_id')
    
    X = df.drop(columns=drop_cols)
    y = df[target_col]
    
    print(f"✓ Features: {X.shape[1]} columns")
    print(f"✓ Target distribution: {y.value_counts().to_dict()}")
    
    # Convert to binary if needed
    if y.dtype == 'object':
        # Handle MITRE classification (match 'Untampered' case-insensitively)
        y_norm = y.astype(str).str.strip().str.lower()
        if (y_norm == 'untampered').any():
            y = (y_norm == 'untampered').astype(int)  # 1 = clean, 0 = tampered
            print("✓ Converted classification to binary (Untampered=1, others=0)")
        else:
            # Fallback mapping for other label sets
            unique_vals = y.unique()
            print(f"  Unique values: {unique_vals}")
            # Assumption: the first value encountered is the clean class; verify this for your dataset
            y = (y == unique_vals[0]).astype(int)
            print(f"✓ Converted to binary (first value={unique_vals[0]}=1, others=0)")
    
    # Ensure the target contains at least two classes
    if len(y.unique()) < 2:
        raise ValueError(f"Target has only one class: {y.unique()}. Need at least 2 classes.")
    
    # Split data
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        print(f"✓ Data split: Train={len(X_train)}, Test={len(X_test)}")
    except ValueError as e:
        print(f"⚠ Stratified split failed (likely a class with too few samples): {e}")
        print("  Using non-stratified split...")
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        print(f"✓ Data split: Train={len(X_train)}, Test={len(X_test)}")
    
    # Scale features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    print("✓ Features scaled successfully")
    
except FileNotFoundError:
    print(f"❌ ERROR: File not found: {dataset_path}")
    print("\nPlease update 'dataset_path' in the cell above to point to your dataset.")
    print("\nExample paths:")
    print("  - MITRE dataset: r'd:\\drone_firmware_full_mitre_dataset.csv'")
    print("  - Windows path: r'C:\\Users\\YourName\\Documents\\dataset.csv'")
    print("  - Relative path: '../data/dataset.csv'")
    raise
except Exception as e:
    print(f"❌ ERROR: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    raise


Loading dataset from: d:\drone_firmware_full_mitre_dataset.csv
❌ ERROR: File not found: d:\drone_firmware_full_mitre_dataset.csv

Please update 'dataset_path' in the cell above to point to your dataset.

Example paths:
  - MITRE dataset: r'd:\drone_firmware_full_mitre_dataset.csv'
  - Windows path: r'C:\Users\YourName\Documents\dataset.csv'
  - Relative path: '../data/dataset.csv'


FileNotFoundError: [Errno 2] No such file or directory: 'd:\\drone_firmware_full_mitre_dataset.csv'

## 3. Detect Overfitting


In [None]:
def detect_overfitting(model, X_train, y_train, X_test, y_test):
    """Detect if model is overfitting"""
    train_pred = model.predict(X_train)
    test_pred = model.predict(X_test)
    
    train_acc = accuracy_score(y_train, train_pred)
    test_acc = accuracy_score(y_test, test_pred)
    
    gap = train_acc - test_acc
    
    print(f"Train Accuracy: {train_acc:.4f}")
    print(f"Test Accuracy:  {test_acc:.4f}")
    print(f"Accuracy Gap:   {gap:.4f}")
    
    if gap > 0.15:
        print("⚠ WARNING: Model is overfitting!")
    elif test_acc > 0.99:
        print("⚠ WARNING: Suspiciously high test accuracy - check for data leakage or duplicate samples!")
    else:
        print("✓ Model looks good")
    
    return train_acc, test_acc, gap
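
To see what the train/test gap looks like in practice, the sketch below compares a fully grown decision tree with a depth-limited one. It is a minimal illustration on synthetic data from `make_classification`, not the firmware dataset; the helper `accuracy_gap` and the `X_demo`/`y_demo` names are hypothetical and exist only for this example.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (assumption): 10% label noise makes memorization visible
X_demo, y_demo = make_classification(
    n_samples=600, n_features=20, n_informative=5, flip_y=0.1, random_state=0
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.25, random_state=0, stratify=y_demo
)

def accuracy_gap(model):
    """Fit the model and return (train accuracy, test accuracy, gap)."""
    model.fit(X_tr, y_tr)
    train_acc = accuracy_score(y_tr, model.predict(X_tr))
    test_acc = accuracy_score(y_te, model.predict(X_te))
    return train_acc, test_acc, train_acc - test_acc

# A fully grown tree can memorize the training set; a shallow tree cannot
deep = accuracy_gap(DecisionTreeClassifier(random_state=0))
shallow = accuracy_gap(DecisionTreeClassifier(max_depth=3, random_state=0))

print(f"Unconstrained tree: train={deep[0]:.3f}, test={deep[1]:.3f}, gap={deep[2]:.3f}")
print(f"max_depth=3 tree:   train={shallow[0]:.3f}, test={shallow[1]:.3f}, gap={shallow[2]:.3f}")
```

A large positive gap for the unconstrained tree is the pattern `detect_overfitting` is meant to flag. The right threshold depends on the dataset, so treat the 0.15 cutoff as a starting point.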


## 4. Use Cross-Validation


In [None]:
def evaluate_with_cv(model, X, y, cv_folds=5):
    """Evaluate model with cross-validation"""
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
    cv_scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')
    
    print(f"Cross-Validation Scores: {cv_scores}")
    print(f"Mean: {cv_scores.mean():.4f} ({cv_scores.mean()*100:.2f}%)")
    print(f"Std:  {cv_scores.std():.4f}")
    print(f"Range: [{cv_scores.min():.4f}, {cv_scores.max():.4f}]")
    
    # The cross-validated mean is a more realistic estimate than a single train/test split
    realistic_accuracy = cv_scores.mean()
    
    return realistic_accuracy, cv_scores
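
One possible refinement is to place preprocessing inside the cross-validation loop, so that the scaler is fitted only on the training portion of each fold. The sketch below shows this with `make_pipeline` on synthetic data; the data and the `pipeline` name are assumptions for illustration. For tree-based models such as Random Forest, scaling has little effect, so this matters more when swapping in scale-sensitive models.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (assumption), not the firmware dataset
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=42)

# The scaler is re-fitted on the training portion of every fold,
# so no statistics from the held-out fold reach the model
pipeline = make_pipeline(
    StandardScaler(),
    RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42),
)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipeline, X_demo, y_demo, cv=cv, scoring='accuracy')
print(f"Mean: {scores.mean():.4f}, Std: {scores.std():.4f}")
```

With this pattern the pipeline receives the unscaled features, since scaling happens inside each fold.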


## 5. Improve Random Forest (Reduce Overfitting)


In [None]:
# Regularized Random Forest
rf_improved = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,  # Limit depth to prevent overfitting
    min_samples_split=10,  # Require more samples to split
    min_samples_leaf=5,  # Require more samples in leaf
    max_features='sqrt',  # Use sqrt of features
    random_state=42,
    n_jobs=-1
)

rf_improved.fit(X_train_scaled, y_train)

# Evaluate
train_acc, test_acc, gap = detect_overfitting(rf_improved, X_train_scaled, y_train, X_test_scaled, y_test)
realistic_acc, cv_scores = evaluate_with_cv(rf_improved, X_train_scaled, y_train)

print(f"\nRecommended Accuracy: {realistic_acc:.4f} ({realistic_acc*100:.2f}%)")
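
The regularization values above are fixed by hand. As a hedged alternative, a small grid search can pick `max_depth` and `min_samples_leaf` by cross-validation. The sketch below uses synthetic data, and the grid values are illustrative assumptions rather than recommendations for the firmware dataset.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Synthetic stand-in data (assumption)
X_demo, y_demo = make_classification(
    n_samples=300, n_features=12, n_informative=4, flip_y=0.05, random_state=42
)

# Illustrative grid over the two main regularization knobs
param_grid = {
    'max_depth': [3, 5, 10],
    'min_samples_leaf': [1, 5, 10],
}

search = GridSearchCV(
    RandomForestClassifier(n_estimators=50, max_features='sqrt', random_state=42),
    param_grid,
    cv=StratifiedKFold(n_splits=3, shuffle=True, random_state=42),
    scoring='accuracy',
)
search.fit(X_demo, y_demo)

print(f"Best parameters: {search.best_params_}")
print(f"Best CV accuracy: {search.best_score_:.4f}")
```

Because the best score is selected from several candidates, it tends to be slightly optimistic. The held-out test set remains the final check.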


## 6. Feature Importance Analysis


In [None]:
# Get feature importance
feature_importance = pd.DataFrame({
    'feature': X.columns,
    'importance': rf_improved.feature_importances_
}).sort_values('importance', ascending=False)

print("Top 10 Most Important Features:")
print(feature_importance.head(10))

# Visualize
plt.figure(figsize=(10, 6))
sns.barplot(data=feature_importance.head(20), x='importance', y='feature')
plt.title('Top 20 Feature Importances')
plt.tight_layout()
plt.show()
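
Impurity-based importances from `feature_importances_` are computed on the training data and can favor high-cardinality features. A complementary check is permutation importance on held-out data, sketched below with `sklearn.inspection.permutation_importance`. The synthetic data and the `feature_names` list are assumptions for illustration only.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption) with hypothetical feature names
X_demo, y_demo = make_classification(
    n_samples=400, n_features=8, n_informative=3, n_redundant=0, random_state=42
)
feature_names = [f"feature_{i}" for i in range(X_demo.shape[1])]
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.25, random_state=42, stratify=y_demo
)

model = RandomForestClassifier(n_estimators=50, max_depth=5, random_state=42)
model.fit(X_tr, y_tr)

# Shuffle each feature on the held-out set and measure the drop in accuracy
result = permutation_importance(model, X_te, y_te, n_repeats=5, random_state=42)

perm_importance = pd.DataFrame({
    'feature': feature_names,
    'importance_mean': result.importances_mean,
    'importance_std': result.importances_std,
}).sort_values('importance_mean', ascending=False)

print(perm_importance)
```

Features whose permutation importance is near zero on held-out data are candidates for removal, even if their impurity-based importance looks high.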


## 7. Confusion Matrix and Classification Report


In [None]:
# Predictions
y_pred = rf_improved.predict(X_test_scaled)

# Confusion Matrix
cm = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:")
print(cm)

# Visualize
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Classification Report
print("\nClassification Report:")
print(classification_report(y_test, y_pred))


## 8. Update MODEL_ACCURACIES in main.py


In [None]:
# After training, update the MODEL_ACCURACIES dictionary in backend/main.py
# Use the realistic accuracy from cross-validation, not test accuracy

print(f"\nUpdate MODEL_ACCURACIES in backend/main.py:")
print(f"'random_forest': {realistic_acc:.4f},  # {realistic_acc*100:.2f}% - Realistic CV accuracy")


## 9. Tips to Prevent Overfitting

1. **Use Cross-Validation**: Report the cross-validated mean accuracy rather than a single train/test score
2. **Regularization**: 
   - Limit tree depth (`max_depth`)
   - Increase `min_samples_split` and `min_samples_leaf`
   - Use `max_features` to limit features per split
3. **More Data**: Collect more training data
4. **Feature Selection**: Remove irrelevant features
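5. **Early Stopping**: For iteratively trained models (e.g., gradient boosting, neural networks), stop training when validation accuracy stops improving
6. **Ensemble Methods**: Combine several models (e.g., bagging or voting) to reduce variance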
7. **Dropout** (for neural networks): Add dropout layers
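
Random Forest has no built-in early stopping, so the early-stopping tip applies to iteratively trained models. As a minimal sketch, scikit-learn's `GradientBoostingClassifier` supports it through `validation_fraction` and `n_iter_no_change`. Note that it monitors the validation loss rather than accuracy. The synthetic data and parameter values below are illustrative assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Synthetic stand-in data (assumption) with some label noise
X_demo, y_demo = make_classification(
    n_samples=500, n_features=20, n_informative=5, flip_y=0.1, random_state=42
)

gb = GradientBoostingClassifier(
    n_estimators=500,         # upper bound on boosting rounds
    validation_fraction=0.2,  # share of training data held out internally
    n_iter_no_change=5,       # stop after 5 rounds without validation improvement
    tol=1e-4,                 # minimum improvement that counts as progress
    random_state=42,
)
gb.fit(X_demo, y_demo)

# n_estimators_ is the number of rounds actually fitted before stopping
print(f"Boosting rounds used: {gb.n_estimators_} of {gb.n_estimators}")
```

If the number of rounds used is well below the upper bound, additional rounds were not improving the validation loss.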
