# Project 15: Vulnerability Prediction in Network Devices

**Objective:** To build an interpretable model that can predict whether a network device is vulnerable based on its device type and software version string.

**Dataset Source:** **Synthetically Generated**. We will create a dataset of network devices with various software versions and programmatically label some as 'vulnerable' based on predefined rules (e.g., older versions are more likely to be vulnerable).

**Model:** We will use a **Decision Tree Classifier**. This model is a good fit because its internal logic is transparent. We can visualize the exact rules the model has learned (e.g., "IF v_major <= 15.5 AND device_type is CISCO_ROUTER, THEN predict 'vulnerable'"), which makes it well suited to explaining decisions to a security team.

**Instructions:**
This notebook is fully self-contained and does not require any external files or APIs. Run all cells in order in Google Colab.

## 1. Import Necessary Libraries

In [None]:
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

## 2. Synthetic Data Generation

In [None]:
print("--- Generating Synthetic Network Device Dataset ---")

# Define some device types and software version patterns
devices = {
    'CISCO_ROUTER': ['15.1(4)M', '15.2(1)T', '15.5(3)S', '16.1.1', '16.3.2'],
    'JUNIPER_FIREWALL': ['18.4R1', '19.2R2', '20.1R1', '20.4R3', '21.2R1'],
    'ARISTA_SWITCH': ['4.20.6M', '4.21.5F', '4.22.1F', '4.23.0F', '4.25.1M']
}

# Generate a list of device records
data = []
for device, versions in devices.items():
    for version in versions:
        # Generate 100 devices for each version
        for _ in range(100):
            data.append([device, version])

df = pd.DataFrame(data, columns=['device_type', 'software_version'])

# --- Define Vulnerability Rules ---
# This is our ground truth. In the real world, this would come from vulnerability scanners.
def is_vulnerable(row):
    device, version = row['device_type'], row['software_version']
    # Match on the version prefix so that, e.g., '15.1' cannot match in the middle of another version.
    if 'CISCO' in device and version.startswith(('15.1', '15.2')):
        return 1
    if 'JUNIPER' in device and version.startswith(('18.4', '19.2')):
        return 1
    if 'ARISTA' in device and version.startswith('4.20'):
        return 1
    return 0 # Not vulnerable

df['is_vulnerable'] = df.apply(is_vulnerable, axis=1)
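# Add one-sided label noise: each row has a 5% chance of being forced to 'vulnerable'.
# The bitwise OR only flips 0 -> 1, never 1 -> 0. A fixed seed keeps the run reproducible.
rng = np.random.default_rng(42)
noise = rng.choice([0, 1], size=len(df), p=[0.95, 0.05])
df['is_vulnerable'] = df['is_vulnerable'] | noise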

print("Dataset generation complete. Sample:")
print(df.sample(5))
print("\nClass Distribution:")
print(df['is_vulnerable'].value_counts())
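
As a sanity check on the class distribution printed above, the expected counts can be worked out by hand from the generation rules. The sketch below is plain arithmetic, assuming the layout of 15 versions with 100 devices each and the 5% noise rate used in this cell. The variable names are illustrative only and nothing here touches the real DataFrame.

```python
# Expected class balance under the generation rules in this notebook.
# All names here are illustrative; nothing below touches the real DataFrame.
devices_per_version = 100
total_versions = 15          # 5 versions for each of the 3 device types
vulnerable_versions = 5      # Cisco 15.1/15.2, Juniper 18.4/19.2, Arista 4.20
noise_rate = 0.05            # probability that a row is forced to 'vulnerable'

total_rows = devices_per_version * total_versions
rule_vulnerable = devices_per_version * vulnerable_versions
rule_clean = total_rows - rule_vulnerable

# OR-ing with the noise can only turn a 0 into a 1, so only clean rows change.
expected_flipped = rule_clean * noise_rate
expected_vulnerable = rule_vulnerable + expected_flipped
expected_clean = total_rows - expected_vulnerable

print(f"Total rows: {total_rows}")
print(f"Expected vulnerable: {expected_vulnerable:.0f}")
print(f"Expected not vulnerable: {expected_clean:.0f}")
```

The counts printed by `value_counts()` should land near these expectations, with some variation from the random draw.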

## 3. Feature Engineering from Version Strings

In [None]:
print("\n--- Engineering Numerical Features from Version Strings ---")

# This function uses regex to parse complex version strings into numerical features.
def parse_version(version):
    # Match patterns like 15.1(4)M -> [15, 1, 4] or 20.4R3 -> [20, 4, 3]
    parts = re.findall(r'(\d+)', version)
    parts = [int(p) for p in parts]
    # Ensure all feature vectors have the same length
    while len(parts) < 3:
        parts.append(0)
    return parts[:3]

# Apply the parsing function
version_features = df['software_version'].apply(parse_version)
df[['v_major', 'v_minor', 'v_patch']] = pd.DataFrame(version_features.tolist(), index=df.index)

print("Feature engineering complete. Sample with new features:")
print(df.sample(5))
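
To make the parsing step concrete, here is a small standalone sketch that applies the same idea to a few version strings from the dataset. It re-implements `parse_version` so it runs on its own; the `'17.3'` and `'15.1(4)T'` strings are hypothetical inputs added to show padding and suffix handling.

```python
import re

def parse_version(version):
    """Return the first three integer groups of a version string, zero-padded."""
    parts = [int(p) for p in re.findall(r'\d+', version)]
    parts += [0] * (3 - len(parts))   # pad short versions such as a hypothetical '17.3'
    return parts[:3]

examples = ['15.1(4)M', '20.4R3', '4.20.6M', '16.1.1', '17.3']
for v in examples:
    print(f"{v:>10} -> {parse_version(v)}")
```

Note that the letter suffixes (M, T, S, F, R) are discarded, so `'15.1(4)M'` and a hypothetical `'15.1(4)T'` map to the same feature vector. That is fine for this dataset, but it would hide any vulnerability that depends on the release train.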

## 4. Data Splitting and Encoding

In [None]:
print("\n--- Splitting and Encoding Data ---")

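# One-hot encode 'device_type'. drop_first=True drops the alphabetically first category
# (ARISTA_SWITCH), which becomes the implicit baseline when both remaining dummy columns are 0.
df_encoded = pd.get_dummies(df, columns=['device_type'], drop_first=True)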

feature_cols = [col for col in df_encoded.columns if col not in ['software_version', 'is_vulnerable']]
X = df_encoded[feature_cols]
y = df_encoded['is_vulnerable']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)
print(f"X_train shape: {X_train.shape}, X_test shape: {X_test.shape}")

## 5. Model Training

In [None]:
print("\n--- Model Training ---")
model = DecisionTreeClassifier(random_state=42, max_depth=4) # Limit depth to keep the tree interpretable

print("Training the Decision Tree model...")
model.fit(X_train, y_train)
print("Training complete.")

## 6. Model Evaluation

In [None]:
print("\n--- Model Evaluation ---")
y_pred = model.predict(X_test)

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['Not Vulnerable (0)', 'Vulnerable (1)']))

print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Purples', xticklabels=['Not Vulnerable', 'Vulnerable'], yticklabels=['Not Vulnerable', 'Vulnerable'])
plt.title('Confusion Matrix')
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.show()
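
When reading the report, it helps to know what scores are achievable at all. Because the injected label noise is random, a model that recovers the clean rules exactly still misclassifies the noise-flipped rows. The arithmetic sketch below estimates that ceiling in expectation over the full dataset, assuming the 500 rule-vulnerable rows, 1000 rule-clean rows, and 5% noise rate from Section 2. Actual test-set numbers will vary with the random draw and with what the tree actually learns.

```python
# Expected scores for a hypothetical classifier that learns the clean rules exactly.
rule_vulnerable = 500        # rows the rules label as vulnerable
rule_clean = 1000            # rows the rules label as not vulnerable
noise_rate = 0.05            # share of rows forced to 'vulnerable'

flipped = rule_clean * noise_rate      # clean rows relabelled as vulnerable

# Confusion-matrix cells for the 'vulnerable' class.
tp = rule_vulnerable                   # predicted 1, labelled 1
fp = 0                                 # OR-noise never turns a 1 into a 0
fn = flipped                           # predicted 0, labelled 1 by noise
tn = rule_clean - flipped              # predicted 0, labelled 0

precision = tp / (tp + fp)
recall = tp / (tp + fn)
accuracy = (tp + tn) / (tp + fp + fn + tn)

print(f"precision={precision:.3f} recall={recall:.3f} accuracy={accuracy:.3f}")
```

Under these assumptions, recall on the 'vulnerable' class tops out at roughly 0.91 and accuracy at roughly 0.97, so scores in that range indicate the tree has captured the underlying rules.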

## 7. Model Interpretability: Visualizing the Decision Tree

In [None]:
print("\n--- Model Interpretability: The Learned Rules ---")

plt.figure(figsize=(20, 12))
plot_tree(
    model,
    feature_names=list(X.columns),  # pass a plain list; some scikit-learn versions reject a pandas Index
    class_names=['Not Vulnerable', 'Vulnerable'],
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title("Decision Tree for Vulnerability Prediction", fontsize=16)
plt.show()
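
The plotted tree can be hard to paste into a ticket or report. As a complement, scikit-learn's `export_text` renders the same rules as indented text. The sketch below is a minimal illustration on hypothetical toy data (`X_toy`, `y_toy`, and `toy_model` are not part of this notebook) that mirrors the Cisco versions from Section 2.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data mirroring the Cisco versions above: [v_major, v_minor].
X_toy = [[15, 1], [15, 2], [15, 5], [16, 1], [16, 3]]
y_toy = [1, 1, 0, 0, 0]      # 1 = vulnerable, 0 = not vulnerable

toy_model = DecisionTreeClassifier(random_state=42).fit(X_toy, y_toy)

# export_text renders the fitted tree as indented IF/ELSE-style rules.
rules = export_text(toy_model, feature_names=['v_major', 'v_minor'])
print(rules)
```

In this notebook, the equivalent call would be `export_text(model, feature_names=list(X.columns))`.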

## 8. Conclusion

In [None]:
print("\n--- Conclusion ---")
print("The Decision Tree model learned to predict device vulnerabilities from device type and parsed software version numbers.")
print("Key Takeaways:")
print("- Check the classification report above for accuracy, precision, and recall. Because some labels are random noise, no model can score perfectly on this dataset.")
print("- The most useful output is the decision tree visualization itself. A network security engineer can read the tree and follow the exact logic the model uses. For example, a node might split on 'v_major <= 15.5', isolating a group of older, higher-risk devices.")
print("- This approach can help move security from a reactive to a proactive stance. Instead of waiting for a vulnerability scanner to run, the model could be integrated with an inventory system (like NetBox or ServiceNow) to provide a continuously updated risk score for every device on the network.")
print("- This allows for intelligent, data-driven patch prioritization, focusing limited maintenance windows on the devices that pose the greatest risk.")
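
One way the risk score mentioned above could be produced is with `predict_proba`, which for a decision tree returns the fraction of training samples of each class in the leaf a device lands in. The following is a minimal sketch under stated assumptions: the training data, the inventory, and the hostnames are all hypothetical, and a real integration would pull records from the inventory system instead.

```python
from sklearn.tree import DecisionTreeClassifier

# Hypothetical training data: a single feature (v_major) with noisy labels.
X_toy = [[15], [15], [15], [15], [16], [16], [16], [16]]
y_toy = [1, 1, 1, 0, 0, 0, 0, 1]

toy_model = DecisionTreeClassifier(max_depth=1, random_state=42).fit(X_toy, y_toy)

# Hypothetical inventory: (hostname, v_major).
inventory = [("edge-rtr-01", 16), ("core-rtr-02", 15), ("lab-rtr-03", 16)]

# Column 1 of predict_proba is the estimated probability of class 1 (vulnerable).
scores = toy_model.predict_proba([[v] for _, v in inventory])[:, 1]

# Patch the highest-risk devices first.
ranked = sorted(zip(inventory, scores), key=lambda pair: pair[1], reverse=True)
for (host, v_major), score in ranked:
    print(f"{host}: v_major={v_major}, risk={score:.2f}")
```

With a shallow tree, many devices share a leaf and therefore the same score, so the ranking is coarse. It still gives a simple ordering for scheduling maintenance windows.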