# Project 11: Network-based Ransomware Detection

## Objective
Build a machine learning model that can identify network traffic patterns associated with ransomware activity, distinguishing them from normal, benign traffic.

## Dataset
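We use the CIC-IDS2017 dataset from Kaggle, which contains labelled benign and attack network flows. Note that CIC-IDS2017 has no ransomware (e.g. WannaCry) label: the attack flows in the Friday-morning capture used here are labelled `Bot` (Ares botnet), so this notebook treats that malicious traffic as a stand-in for ransomware activity.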

## Key Features
- Handling extreme class imbalance (ransomware events are rare)
- RandomForestClassifier with balanced class weights
- Focus on recall to minimize missed attacks
- Network forensics through feature importance analysis

## 1. Environment Setup and Data Loading

In [None]:
# Install required packages
!pip install kaggle pandas numpy scikit-learn matplotlib seaborn

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

In [None]:
# Setup Kaggle API (requires kaggle.json in ~/.kaggle/)
if not os.path.exists(os.path.expanduser('~/.kaggle/kaggle.json')):
    print("Please set up your Kaggle API credentials first.")
    print("1. Go to https://www.kaggle.com/account")
    print("2. Create API token and download kaggle.json")
    print("3. Place it in ~/.kaggle/ directory")
else:
    print("Kaggle API configured. Downloading dataset...")
    !kaggle datasets download -d cicdataset/cicids2017 --unzip

## 2. Data Loading and Preprocessing

In [None]:
# Load the two CIC-IDS2017 captures used in this project.
print("Loading network traffic data...")

# Benign traffic from Monday working hours
benign_path = 'MachineLearningCVE/Monday-WorkingHours.pcap_ISCX.csv'
# Attack traffic from Friday morning, used as the positive ("ransomware") class.
# NOTE(review): the CIC-IDS2017 documentation lists this capture's attack as the
# Ares botnet (raw label 'Bot'), not WannaCry - confirm via
# df_ransomware[' Label'].unique().
ransomware_path = 'MachineLearningCVE/Friday-WorkingHours-Morning.pcap_ISCX.csv'

try:
    df_benign = pd.read_csv(benign_path, encoding='utf-8', low_memory=False)
    df_ransomware = pd.read_csv(ransomware_path, encoding='utf-8', low_memory=False)
except FileNotFoundError:
    print("Dataset files not found. Please ensure the dataset is downloaded correctly.")
    print("Expected files:")
    print(f"- {benign_path}")
    print(f"- {ransomware_path}")
    # Stop here: every later cell needs df_benign / df_ransomware, and carrying
    # on would only surface as a confusing NameError in the next cell.
    raise
else:
    print(f"Benign traffic shape: {df_benign.shape}")
    print(f"Ransomware traffic shape: {df_ransomware.shape}")

In [None]:
# Prepare labels.
# The raw CSV headers carry a leading space (e.g. ' Label'). Strip them on each
# frame *before* adding our own label. Otherwise the raw ' Label' column and
# the new 'Label' column both become 'Label' after stripping, and the duplicate
# columns make df['Label'] a DataFrame, which breaks the label encoding.
df_benign.columns = df_benign.columns.str.strip()
df_ransomware.columns = df_ransomware.columns.str.strip()

# Every flow in the Monday capture is treated as benign.
df_benign['Label'] = 'Benign'
# Friday file: any flow whose raw label is not BENIGN becomes the positive class.
df_ransomware['Label'] = df_ransomware['Label'].apply(
    lambda x: 'Benign' if x.strip() == 'BENIGN' else 'Ransomware'
)

# Combine datasets (now with a single 'Label' column)
df = pd.concat([df_benign, df_ransomware], ignore_index=True)
print(f"Combined dataset shape: {df.shape}")

# Clean column names (already stripped per frame above; kept as a safeguard)
df.columns = df.columns.str.strip()
print("Column names cleaned.")

In [None]:
# Data cleaning
print("Cleaning data...")
print(f"Shape before cleaning: {df.shape}")

# Treat +/-inf as missing so those rows are dropped together with NaN rows.
df.replace([np.inf, -np.inf], np.nan, inplace=True)

# Remove rows with NaN values
df.dropna(inplace=True)
print(f"Shape after removing NaN/infinite values: {df.shape}")

# Encode labels: Ransomware -> 1, Benign -> 0 (vectorised comparison)
df['Label'] = df['Label'].eq('Ransomware').astype(int)

# Check class distribution. reindex guarantees both classes are present
# (count 0 if absent), so the lookups below cannot raise KeyError.
print("\nClass Distribution:")
class_counts = df['Label'].value_counts().reindex([0, 1], fill_value=0)
print(f"Benign (0): {class_counts[0]:,}")
print(f"Ransomware (1): {class_counts[1]:,}")
if class_counts[1] == 0:
    # Nothing to detect - fail loudly instead of dividing by zero below.
    raise ValueError("No ransomware rows remain after cleaning; cannot train a detector.")
print(f"Imbalance ratio: {class_counts[0]/class_counts[1]:.1f}:1")

## 3. Handling Class Imbalance

In [None]:
# Address extreme class imbalance through downsampling
print("Addressing class imbalance...")

# Number of benign rows to keep per ransomware row.
DOWNSAMPLE_RATIO = 5

df_majority = df[df['Label'] == 0]  # Benign
df_minority = df[df['Label'] == 1]  # Ransomware

# Downsample the majority class to DOWNSAMPLE_RATIO x the minority class size.
# Cap at the available benign rows: DataFrame.sample(n=...) without
# replacement raises ValueError when n exceeds the population.
n_majority = min(len(df_minority) * DOWNSAMPLE_RATIO, len(df_majority))
df_majority_downsampled = df_majority.sample(n=n_majority, random_state=42)
df_balanced = pd.concat([df_majority_downsampled, df_minority])

print("\nClass distribution after downsampling:")
balanced_counts = df_balanced['Label'].value_counts()
print(f"Benign (0): {balanced_counts[0]:,}")
print(f"Ransomware (1): {balanced_counts[1]:,}")
print(f"New ratio: {balanced_counts[0]/balanced_counts[1]:.1f}:1")

# Use balanced dataset for training.
# NOTE: this happens before the train/test split, so the test set also has the
# artificial ratio above. Precision measured later will therefore be higher
# than on real traffic, where benign flows vastly outnumber attacks.
df = df_balanced

## 4. Feature Preparation and Data Splitting

In [None]:
# Separate the binary target column from the feature matrix
y = df['Label']
X = df.drop(columns=['Label'])

print(f"Feature matrix shape: {X.shape}")
print(f"Target vector shape: {y.shape}")

# Hold out 30% for testing; stratifying on y keeps the benign/ransomware
# proportion identical in both splits.
split_options = dict(test_size=0.3, random_state=42, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

n_train, n_test = len(X_train), len(X_test)
print(f"\nTraining set: {n_train:,} samples")
print(f"Test set: {n_test:,} samples")
print("Training class distribution:")
print(y_train.value_counts().sort_index())

## 5. Model Training

In [None]:
# Random forest for the (still imbalanced) downsampled training data
print("Training RandomForest Classifier...")

forest_params = {
    'n_estimators': 100,
    'random_state': 42,
    # Weight classes inversely to their frequency to offset the remaining imbalance.
    'class_weight': 'balanced',
    'n_jobs': -1,  # use every available CPU core
    # Depth / split / leaf limits regularise the trees against overfitting.
    'max_depth': 15,
    'min_samples_split': 10,
    'min_samples_leaf': 5,
}
model = RandomForestClassifier(**forest_params)

model.fit(X_train, y_train)
print("Training completed successfully!")

## 6. Model Evaluation

In [None]:
# Make predictions
y_pred = model.predict(X_test)
# Probability of the positive (ransomware) class; used for ROC-AUC later.
y_pred_proba = model.predict_proba(X_test)[:, 1]

# Classification report. labels=[0, 1] keeps target_names aligned with the
# classes even if one class is missing from y_test and y_pred.
print("Classification Report:")
print("=" * 50)
print(classification_report(
    y_test,
    y_pred,
    labels=[0, 1],
    target_names=['Benign (0)', 'Ransomware (1)'],
    zero_division=0,
))

# Confusion matrix. labels=[0, 1] forces a 2x2 result so the unpacking below
# always works. Layout: rows = actual class, columns = predicted class.
print("\nConfusion Matrix Analysis:")
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
print(f"True Negatives (Correctly identified benign): {tn}")
print(f"False Positives (Benign flagged as ransomware): {fp}")
print(f"False Negatives (Missed ransomware - CRITICAL): {fn}")
print(f"True Positives (Correctly identified ransomware): {tp}")

In [None]:
# Annotated heatmap of the confusion matrix (rows = actual, columns = predicted)
class_names = ['Benign', 'Ransomware']

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Reds',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_title('Confusion Matrix - Network Ransomware Detection')
ax.set_ylabel('Actual Label')
ax.set_xlabel('Predicted Label')
plt.show()

print("⚠️  False Negatives (bottom-left) are the most dangerous metric!")
print("   These represent real ransomware attacks that were missed.")

## 7. Feature Importance Analysis

In [None]:
# Rank every input feature by the forest's importance score, highest first
feature_importance = (
    pd.DataFrame({'feature': X.columns, 'importance': model.feature_importances_})
    .sort_values('importance', ascending=False)
)

# Print the 15 top-ranked features as a numbered table
print("Top 15 Most Important Features for Ransomware Detection:")
print("=" * 60)
top_rows = feature_importance.head(15).itertuples(index=False)
for rank, row in enumerate(top_rows, start=1):
    print(f"{rank:2d}. {row.feature:<35} {row.importance:.4f}")

In [None]:
# Horizontal bar chart of the 15 highest-ranked features
top_15_features = feature_importance.head(15)
bar_positions = range(len(top_15_features))

plt.figure(figsize=(12, 8))
plt.barh(bar_positions, top_15_features['importance'].values, color='darkred')
plt.yticks(bar_positions, top_15_features['feature'].values)
plt.xlabel('Feature Importance')
plt.title('Top 15 Network Features for Ransomware Detection')
# barh puts row 0 at the bottom; flip so the most important feature is on top.
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()

# NOTE(review): these talking points are fixed text, not derived from the
# importances computed above; check them against the chart before relying on them.
forensic_insights = (
    "• Flow timing features (IAT) are highly predictive",
    "• Packet flags (URG, PSH) reveal ransomware communication patterns",
    "• Flow duration and idle times show scanning behavior",
    "• These patterns enable early detection before file encryption",
)
print("\n🔍 Network Forensics Insights:")
for insight in forensic_insights:
    print(insight)

## 8. Model Performance Analysis

In [None]:
# Performance metrics calculation
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

accuracy = accuracy_score(y_test, y_pred)
# zero_division=0: report 0 (rather than warn) if a denominator is empty,
# e.g. when the model predicts no positives at all.
precision = precision_score(y_test, y_pred, zero_division=0)
recall = recall_score(y_test, y_pred, zero_division=0)
f1 = f1_score(y_test, y_pred, zero_division=0)
auc = roc_auc_score(y_test, y_pred_proba)

print("🎯 Model Performance Summary:")
print("=" * 40)
print(f"Accuracy:  {accuracy:.3f} ({accuracy*100:.1f}%)")
print(f"Precision: {precision:.3f} ({precision*100:.1f}%)")
print(f"Recall:    {recall:.3f} ({recall*100:.1f}%) ← Most Critical")
print(f"F1-Score:  {f1:.3f}")
print(f"AUC-ROC:   {auc:.3f}")

print("\n📊 Business Impact:")
print(f"• Catches {recall*100:.1f}% of ransomware attacks")
print(f"• {precision*100:.1f}% of alerts are true positives")
# The test split was drawn after benign traffic was downsampled, so its class
# mix is far less skewed than real traffic; precision does not transfer as-is.
print("  (measured on the downsampled test split; expect lower precision on real")
print("   traffic, where benign flows far outnumber attacks)")
print(f"• Missed attacks: {cm[1,0]} out of {cm[1,0] + cm[1,1]} total")

## 9. Conclusion and Next Steps

In [None]:
print("🛡️  PROJECT CONCLUSION: Network-based Ransomware Detection")
print("=" * 70)

# Imbalance ratio of the cleaned data before downsampling, taken from the
# class counts computed during cleaning rather than hard-coded.
original_ratio = class_counts[0] / class_counts[1]

print("\n✅ Key Achievements:")
print(f"  • Developed ML model with {recall*100:.1f}% recall for ransomware detection")
print(f"  • Successfully handled extreme class imbalance (original ratio {original_ratio:.0f}:1)")
print("  • Identified key network indicators: flow timing, packet flags, idle patterns")
print("  • Created interpretable model for security analyst investigations")

print("\n🎯 Security Value:")
print("  • Early warning system before file encryption occurs")
print("  • Network-level detection complements endpoint solutions")
print("  • Forensic insights for threat hunting and incident response")
print("  • Foundation for automated quarantine and response workflows")

print("\n🚀 Next Steps for Production:")
print("  1. Deploy in network monitoring infrastructure")
print("  2. Integrate with SIEM for automated alerting")
print("  3. Implement real-time stream processing")
print("  4. Add model retraining pipeline for new ransomware variants")
print("  5. Develop threat intelligence integration")

print("\n⚡ Technical Recommendations:")
print("  • Monitor model performance for concept drift")
print("  • Implement ensemble methods for improved robustness")
print("  • Add temporal analysis for sequential attack patterns")
print("  • Consider deep learning for complex attack variants")