In [2]:
# Standard library imports
import logging   # For logging messages and debugging information
import datetime  # For tracking execution time and timestamps
from datetime import datetime

# Third-party imports for data processing and machine learning
import torch                                           # PyTorch deep learning framework
from torch import nn                                   # Neural network modules
from torch.utils.data import Dataset, DataLoader       # Data handling utilities
import pandas as pd                                    # For DataFrame operations
import numpy as np                                     # For numerical operations
from sklearn.model_selection import train_test_split   # For splitting dataset
from sklearn.preprocessing import MultiLabelBinarizer  # For label encoding
import matplotlib.pyplot as plt                        # For plotting learning curves
from tqdm import tqdm                                  # For progress bars

# Hugging Face transformers imports
from transformers import BertModel, BertTokenizer, AdamW  # BERT model and utilities

In [3]:
# Configure logging to show timestamp, log level, and message
logging.basicConfig(
    level=logging.INFO,  # Show all info messages and above (info, warning, error, critical)
    format='%(asctime)s - %(levelname)s - %(message)s',  # Format: timestamp - level - message
    datefmt='%Y-%m-%d %H:%M:%S'  # Date format for the timestamp
)

# Log the start of the program and verify logging is working
logging.info("Starting program and initializing imports...")

2025-01-10 08:34:00 - INFO - Starting program and initializing imports...


In [4]:
class ShellAttackDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Log dataset statistics
        logging.info(f"Created dataset with {len(texts)} samples")
        logging.info(f"Number of labels: {labels.shape[1]}")
        
        # Calculate and log average sequence length
        avg_len = np.mean([len(" ".join(text).split()) for text in texts])
        logging.info(f"Average sequence length (words): {avg_len:.2f}")
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = " ".join(self.texts[idx])
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(self.labels[idx])
        }

In [5]:
class BertClassifier(nn.Module):
    def __init__(self, num_labels):
        super(BertClassifier, self).__init__()
        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # Add custom classification head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()
        
        # Log model architecture details
        logging.info(f"Initialized BERT Classifier with:")
        logging.info(f"- BERT hidden size: {self.bert.config.hidden_size}")
        logging.info(f"- Number of labels: {num_labels}")
        logging.info(f"- Dropout rate: 0.1")
    
    def forward(self, input_ids, attention_mask):
        # Get BERT outputs
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Use the [CLS] token representation
        pooled_output = outputs.pooler_output
        
        # Apply dropout and classification
        x = self.dropout(pooled_output)
        x = self.classifier(x)
        return self.sigmoid(x)

In [6]:
def train_model(model, train_loader, val_loader, device, num_epochs=10):
    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=2e-5)
    
    # Binary Cross Entropy loss for multi-label classification
    criterion = nn.BCELoss()
    
    # Store metrics
    train_losses = []
    val_losses = []
    
    # Calculate total number of training steps
    total_steps = len(train_loader) * num_epochs
    logging.info(f"Starting training with {total_steps} total steps")
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        train_steps = 0
        
        epoch_start_time = datetime.now()
        
        for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            train_steps += 1
        
        avg_train_loss = total_train_loss / train_steps
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        total_val_loss = 0
        val_steps = 0
        
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                loss = criterion(outputs, labels)
                
                total_val_loss += loss.item()
                val_steps += 1
        
        avg_val_loss = total_val_loss / val_steps
        val_losses.append(avg_val_loss)
        
        epoch_time = datetime.now() - epoch_start_time
        
        logging.info(f"\nEpoch {epoch + 1} Summary:")
        logging.info(f"Time taken: {epoch_time}")
        logging.info(f"Average training loss: {avg_train_loss:.4f}")
        logging.info(f"Average validation loss: {avg_val_loss:.4f}")
        
        # Calculate and log loss improvement
        if epoch > 0:
            train_improvement = train_losses[-2] - train_losses[-1]
            val_improvement = val_losses[-2] - val_losses[-1]
            logging.info(f"Training loss improvement: {train_improvement:.4f}")
            logging.info(f"Validation loss improvement: {val_improvement:.4f}")
    
    return train_losses, val_losses

In [None]:
def save_training_results(model, training_history, mlb, save_path='training_results.pt'):
    """Save model, training history and label binarizer."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'training_history': training_history,
        'label_binarizer_classes': mlb.classes_
    }, save_path)
    logging.info(f"Training results saved to {save_path}")

def load_training_results(model_class, save_path='training_results.pt'):
    """Load model and training history."""
    checkpoint = torch.load(save_path)
    
    # Recreate model with same number of labels
    model = model_class(num_labels=len(checkpoint['label_binarizer_classes']))
    model.load_state_dict(checkpoint['model_state_dict'])
    
    return model, checkpoint['training_history'], checkpoint['label_binarizer_classes']

In [9]:
# Start timing
start_time = datetime.now()
logging.info("Starting data processing and model training pipeline")
    
# Load data
logging.info("Loading dataset from parquet file...")
df = pd.read_parquet("../data/processed/ssh_attacks_decoded.parquet")
logging.info(f"Loaded dataset with {len(df)} rows")
    
# Log data statistics
logging.info("\nDataset Statistics:")
logging.info(f"Number of unique session IDs: {df['session_id'].nunique()}")
logging.info(f"Date range: {df['first_timestamp'].min()} to {df['first_timestamp'].max()}")

2025-01-10 08:34:25 - INFO - Starting data processing and model training pipeline
2025-01-10 08:34:25 - INFO - Loading dataset from parquet file...
2025-01-10 08:34:27 - INFO - Loaded dataset with 233035 rows
2025-01-10 08:34:27 - INFO - 
Dataset Statistics:
2025-01-10 08:34:27 - INFO - Number of unique session IDs: 233035
2025-01-10 08:34:27 - INFO - Date range: 2019-06-04 09:45:11.151186+00:00 to 2020-02-29 23:59:22.199490+00:00


In [10]:
# Process labels
logging.info("\nProcessing labels...")
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform([set(x) for x in df['Set_Fingerprint']])
logging.info(f"Number of unique labels: {len(mlb.classes_)}")
logging.info("Most common labels:")
label_counts = pd.Series([label for labels_list in df['Set_Fingerprint'] for label in labels_list]).value_counts()
for label, count in label_counts.head().items():
    logging.info(f"- {label}: {count} occurrences")

2025-01-10 08:34:31 - INFO - 
Processing labels...
2025-01-10 08:34:32 - INFO - Number of unique labels: 7
2025-01-10 08:34:32 - INFO - Most common labels:
2025-01-10 08:34:32 - INFO - - Discovery: 232145 occurrences
2025-01-10 08:34:32 - INFO - - Persistence: 211295 occurrences
2025-01-10 08:34:32 - INFO - - Execution: 92927 occurrences
2025-01-10 08:34:32 - INFO - - Defense Evasion: 18999 occurrences
2025-01-10 08:34:32 - INFO - - Harmless: 2206 occurrences


In [11]:
# Split data
logging.info("\nSplitting data into train and validation sets...")
X_train, X_val, y_train, y_val = train_test_split(
    df['full_session'].values,
    labels,
    test_size=0.2,
    random_state=42
)
logging.info(f"Training set size: {len(X_train)}")
logging.info(f"Validation set size: {len(X_val)}")

2025-01-10 08:34:34 - INFO - 
Splitting data into train and validation sets...
2025-01-10 08:34:34 - INFO - Training set size: 186428
2025-01-10 08:34:34 - INFO - Validation set size: 46607


In [12]:
# Initialize tokenizer
logging.info("\nInitializing BERT tokenizer...")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

2025-01-10 08:34:35 - INFO - 
Initializing BERT tokenizer...


In [13]:
# Create datasets
logging.info("Creating datasets...")
train_dataset = ShellAttackDataset(X_train, y_train, tokenizer)
val_dataset = ShellAttackDataset(X_val, y_val, tokenizer)
    
# Create data loaders
batch_size = 16
logging.info(f"\nCreating data loaders with batch size {batch_size}...")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
logging.info(f"Number of training batches: {len(train_loader)}")
logging.info(f"Number of validation batches: {len(val_loader)}")

2025-01-10 08:34:35 - INFO - Creating datasets...
2025-01-10 08:34:35 - INFO - Created dataset with 186428 samples
2025-01-10 08:34:35 - INFO - Number of labels: 7
2025-01-10 08:34:37 - INFO - Average sequence length (words): 69.70
2025-01-10 08:34:37 - INFO - Created dataset with 46607 samples
2025-01-10 08:34:37 - INFO - Number of labels: 7
2025-01-10 08:34:37 - INFO - Average sequence length (words): 69.59
2025-01-10 08:34:37 - INFO - 
Creating data loaders with batch size 16...
2025-01-10 08:34:37 - INFO - Number of training batches: 11652
2025-01-10 08:34:37 - INFO - Number of validation batches: 2913


In [14]:
# Initialize model
logging.info("\nInitializing model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f"Using device: {device}")
model = BertClassifier(num_labels=len(mlb.classes_))
model.to(device)
    
# Count model parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f"Total parameters: {total_params:,}")
logging.info(f"Trainable parameters: {trainable_params:,}")

2025-01-10 08:34:37 - INFO - 
Initializing model...
2025-01-10 08:34:37 - INFO - Using device: cpu
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2025-01-10 08:34:43 - INFO - Initialized BERT Classifier with:
2025-01-10 08:34:43 - INFO 

In [None]:
# Train the model
logging.info("\nStarting model training...")
train_losses, val_losses = train_model(model, train_loader, val_loader, device, epochs=4)

2025-01-10 08:34:43 - INFO - 
Starting model training...
2025-01-10 08:34:43 - INFO - Starting training with 116520 total steps
Epoch 1/10:   0%|          | 1/11652 [01:32<300:29:17, 92.85s/it]

In [None]:
# After training:
training_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'true_labels': y_true,
    'predicted_probabilities': y_pred_probs,
    'probabilities_per_epoch': y_pred_probs_per_epoch
}

save_training_results(model, training_history, mlb)

# Later, to load:
# model, history, class_names = load_training_results(BertClassifier)

In [None]:
def plot_learning_curves(train_losses, val_losses):
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Learning Curves')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# Plot learning curves
logging.info("Plotting learning curves...")
plot_learning_curves(train_losses, val_losses)

In [None]:
# Log total execution time
total_time = datetime.now() - start_time
logging.info(f"\nTotal execution time: {total_time}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc, f1_score
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

In [None]:
def plot_metrics_over_epochs(y_true, y_pred_probs, epochs):
    """Plot TPR, FPR, TNR, FNR over epochs."""
    metrics = {
        'TPR': [], 'FPR': [], 'TNR': [], 'FNR': []
    }
    
    for epoch_preds in y_pred_probs:
        epoch_pred_binary = (epoch_preds > 0.5).astype(int)
        tp = np.sum((y_true == 1) & (epoch_pred_binary == 1), axis=0)
        fp = np.sum((y_true == 0) & (epoch_pred_binary == 1), axis=0)
        tn = np.sum((y_true == 0) & (epoch_pred_binary == 0), axis=0)
        fn = np.sum((y_true == 1) & (epoch_pred_binary == 0), axis=0)
        
        metrics['TPR'].append(tp / (tp + fn))
        metrics['FPR'].append(fp / (fp + tn))
        metrics['TNR'].append(tn / (tn + fp))
        metrics['FNR'].append(fn / (fn + tp))
    
    plt.figure(figsize=(12, 8))
    for metric, values in metrics.items():
        plt.plot(range(1, epochs + 1), np.mean(values, axis=1), label=metric)
    
    plt.xlabel('Epoch')
    plt.ylabel('Rate')
    plt.title('Metrics Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_loss_curves(train_losses, val_losses):
    """Plot training and validation loss curves."""
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_roc_curves(y_true, y_pred_probs, class_names):
    """Plot ROC curves for each class."""
    plt.figure(figsize=(15, 10))
    for i in range(y_true.shape[1]):
        fpr, tpr, _ = roc_curve(y_true[:, i], y_pred_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {roc_auc:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for Each Class')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.show()

def plot_pr_curves(y_true, y_pred_probs, class_names):
    """Plot Precision-Recall curves for each class."""
    plt.figure(figsize=(15, 10))
    for i in range(y_true.shape[1]):
        precision, recall, _ = precision_recall_curve(y_true[:, i], y_pred_probs[:, i])
        plt.plot(recall, precision, label=class_names[i])
    
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves for Each Class')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.show()

def plot_prob_histograms(y_pred_probs, class_names):
    """Plot histograms of prediction probabilities for each class."""
    n_classes = len(class_names)
    n_cols = 3
    n_rows = (n_classes + n_cols - 1) // n_cols
    
    plt.figure(figsize=(15, 5 * n_rows))
    for i in range(n_classes):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.hist(y_pred_probs[:, i], bins=50)
        plt.title(f'Class: {class_names[i]}')
        plt.xlabel('Prediction Probability')
        plt.ylabel('Count')
    plt.tight_layout()
    plt.show()

def plot_3d_roc(y_true, y_pred_probs, class_names, n_classes=3):
    """Plot 3D ROC curve for the first three classes."""
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')
    
    for i in range(min(n_classes, len(class_names))):
        fpr, tpr, _ = roc_curve(y_true[:, i], y_pred_probs[:, i])
        ax.plot(fpr, tpr, zs=i, zdir='y', label=class_names[i])
    
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('Class')
    ax.set_zlabel('True Positive Rate')
    ax.set_title('3D ROC Curve')
    plt.legend()
    plt.show()

def plot_f1_scores(y_true, y_pred_probs, class_names):
    """Plot class-wise F1 scores."""
    y_pred = (y_pred_probs > 0.5).astype(int)
    f1_scores = [f1_score(y_true[:, i], y_pred[:, i]) for i in range(len(class_names))]
    
    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, f1_scores)
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Class')
    plt.ylabel('F1 Score')
    plt.title('Class-wise F1 Scores')
    
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}',
                ha='center', va='bottom')
    plt.tight_layout()
    plt.show()

def plot_performance_metrics(y_true, y_pred_probs, class_names):
    """Plot precision, recall, and F1 score for each class."""
    y_pred = (y_pred_probs > 0.5).astype(int)
    metrics = {
        'Precision': [],
        'Recall': [],
        'F1': []
    }
    
    for i in range(len(class_names)):
        tp = np.sum((y_true[:, i] == 1) & (y_pred[:, i] == 1))
        fp = np.sum((y_true[:, i] == 0) & (y_pred[:, i] == 1))
        fn = np.sum((y_true[:, i] == 1) & (y_pred[:, i] == 0))
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        metrics['Precision'].append(precision)
        metrics['Recall'].append(recall)
        metrics['F1'].append(f1)
    
    x = np.arange(len(class_names))
    width = 0.25
    
    fig, ax = plt.subplots(figsize=(15, 8))
    ax.bar(x - width, metrics['Precision'], width, label='Precision')
    ax.bar(x, metrics['Recall'], width, label='Recall')
    ax.bar(x + width, metrics['F1'], width, label='F1')
    
    ax.set_ylabel('Score')
    ax.set_title('Model Performance Metrics by Class')
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Import the plotting functions
from plotting_utils import *

# Assuming you have these variables from your model training:
# y_true: Ground truth labels
# y_pred_probs: Model's predicted probabilities
# train_losses: List of training losses per epoch
# val_losses: List of validation losses per epoch
# class_names: List of your label names
# y_pred_probs_per_epoch: List of prediction probabilities for each epoch

# Plot all metrics
def plot_all_metrics(y_true, y_pred_probs, y_pred_probs_per_epoch, train_losses, val_losses, class_names):
    # 1. Metrics over epochs
    plot_metrics_over_epochs(y_true, y_pred_probs_per_epoch, len(train_losses))
    
    # 2. Loss curves
    plot_loss_curves(train_losses, val_losses)
    
    # 3. ROC curves
    plot_roc_curves(y_true, y_pred_probs, class_names)
    
    # 4. Precision-Recall curves
    plot_pr_curves(y_true, y_pred_probs, class_names)
    
    # 5. Probability histograms
    plot_prob_histograms(y_pred_probs, class_names)
    
    # 6. 3D ROC curve
    plot_3d_roc(y_true, y_pred_probs, class_names)
    
    # 7. F1 scores
    plot_f1_scores(y_true, y_pred_probs, class_names)
    
    # 8. Performance metrics
    plot_performance_metrics(y_true, y_pred_probs, class_names)

# Example usage:
"""
# After training your model:
y_true = model_evaluation['true_labels']
y_pred_probs = model_evaluation['predicted_probabilities']
y_pred_probs_per_epoch = model_evaluation['probabilities_per_epoch']
train_losses = model_evaluation['train_losses']
val_losses = model_evaluation['val_losses']
class_names = mlb.classes_  # from your MultiLabelBinarizer

# Generate all plots
plot_all_metrics(y_true, y_pred_probs, y_pred_probs_per_epoch, 
                train_losses, val_losses, class_names)

# Or call individual functions as needed:
plot_roc_curves(y_true, y_pred_probs, class_names)
plot_loss_curves(train_losses, val_losses)
"""