# 🛡️ SSH Brute Force Detection - Minimal LSTM Pipeline

A minimal end-to-end LSTM training pipeline for SSH brute-force detection.

**Workflow:**
1. Install & Setup
2. Upload SSH.log
3. Parse + Build Sequences
4. Train LSTM
5. Evaluate
6. Export ssh_lstm.joblib


In [None]:
# =============================================================================
# CELL 1: INSTALL & SETUP
# =============================================================================
%pip install -q pandas numpy python-dateutil tensorflow joblib

import re
import warnings
import numpy as np
import pandas as pd
from datetime import datetime
from dateutil import parser as dateparser
from typing import Optional

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import joblib

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.callbacks import EarlyStopping

# Silence Python warnings and restrict TensorFlow's logger to errors so the
# notebook output stays readable.
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# Reproducibility: seed NumPy and TensorFlow with the same fixed value.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

# === CONFIGURATION ===
# Windowing and weak-labeling parameters used when sequences are built.
WINDOW_SIZE = 20        # Events per window (also the model's input sequence length)
STRIDE = 5              # Step, in events, between the starts of consecutive windows
TIME_WINDOW_SEC = 120   # Max time span in seconds (2 minutes) for a window to be labeled attack
FAIL_THRESHOLD = 10     # Min suspicious events in a window for the attack label

# LSTM params
EMBEDDING_DIM = 32      # Output dimension of the event-token Embedding layer
LSTM_UNITS = 64         # Units of the LSTM that is wrapped in Bidirectional
DROPOUT = 0.3           # Rate of the first Dropout layer; the second uses DROPOUT / 2
EPOCHS = 15             # Upper bound on epochs; EarlyStopping may stop sooner
BATCH_SIZE = 256        # Mini-batch size passed to model.fit
PATIENCE = 5            # EarlyStopping patience (epochs without val_loss improvement)

# Echo the environment and key settings so a saved notebook records them.
print(f"TensorFlow: {tf.__version__}")
print(f"Config: WINDOW_SIZE={WINDOW_SIZE}, STRIDE={STRIDE}, FAIL_THRESHOLD={FAIL_THRESHOLD}")
print("‚úÖ Setup complete")


In [None]:
# =============================================================================
# CELL 2: UPLOAD SSH.log
# =============================================================================
# google.colab is only importable inside a Colab runtime.
from google.colab import files

print("üìÅ Please upload your SSH.log file:")
# The keys of the returned mapping are used as file paths below; the values are
# the uploaded contents, whose length is reported in bytes.
uploaded = files.upload()

# Without this guard an empty result (e.g. the dialog was dismissed without
# picking a file) surfaces as an opaque IndexError on the lookup below.
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and select your SSH.log file.")

# Only the first uploaded file is used.
log_path = next(iter(uploaded))
print(f"‚úÖ Loaded: {log_path} ({len(uploaded[log_path]):,} bytes)")


In [None]:
# =============================================================================
# CELL 3: PARSE + BUILD SEQUENCES
# =============================================================================

# Event taxonomy: an event's position in this list is its integer token id,
# so the order must not change once a model has been trained ('PAD' is id 0).
EVENT_TYPES = [
    'PAD',
    'FAILED_PASSWORD',
    'INVALID_USER',
    'ACCEPTED_PASSWORD',
    'ACCEPTED_PUBLICKEY',
    'DISCONNECT',
    'REVERSE_DNS_FAIL',
    'PAM_AUTH_FAILURE',
    'CONNECTION_CLOSED',
    'SESSION_OPENED',
    'SESSION_CLOSED',
    'OTHER',
]

# Event name -> token id, and the resulting vocabulary size.
token2id = dict(zip(EVENT_TYPES, range(len(EVENT_TYPES))))
VOCAB_SIZE = len(token2id)

# Event types that count toward the attack label of a window.
SUSPICIOUS_EVENTS = {
    'FAILED_PASSWORD',
    'INVALID_USER',
    'PAM_AUTH_FAILURE',
    'REVERSE_DNS_FAIL',
}

# IP extraction patterns, tried in this order; the first pattern that matches
# anywhere in the line wins, regardless of where in the line it matched.
IP_PATTERNS = [
    re.compile(r'from\s+((?:\d{1,3}\.){3}\d{1,3})'),  # "... from 1.2.3.4"
    re.compile(r'by\s+((?:\d{1,3}\.){3}\d{1,3})'),    # "... by 1.2.3.4"
    re.compile(r'rhost=((?:\d{1,3}\.){3}\d{1,3})'),   # "rhost=1.2.3.4"
    re.compile(r'\[((?:\d{1,3}\.){3}\d{1,3})\]'),     # "[1.2.3.4]"
]


def extract_ip(line: str) -> Optional[str]:
    """Return the first dotted-quad address found in *line*, or None.

    Patterns are consulted in IP_PATTERNS order, so pattern priority (not
    position within the line) decides which address is returned. Octets are
    not range-checked: anything shaped like four 1-3 digit groups matches.
    """
    hits = (pattern.search(line) for pattern in IP_PATTERNS)
    return next((hit.group(1) for hit in hits if hit), None)

def classify_event(line: str) -> str:
    """Map a raw log line to one of the event-type names.

    Matching is case-insensitive and substring-based. Rules are checked in
    the order listed below and the first hit wins; a line matching none of
    them is classified as 'OTHER'.
    """
    text = line.lower()

    # (label, matched?) pairs in priority order. Evaluating every test up
    # front is harmless because they are pure substring checks.
    rules = (
        ('FAILED_PASSWORD', 'failed password' in text),
        # Lines containing "failed" anywhere are deliberately excluded here.
        ('INVALID_USER', 'invalid user' in text and 'failed' not in text),
        ('ACCEPTED_PUBLICKEY', 'accepted publickey' in text),
        ('ACCEPTED_PASSWORD', 'accepted password' in text),
        ('PAM_AUTH_FAILURE',
         'authentication failure' in text or ('pam_unix' in text and 'auth' in text)),
        ('REVERSE_DNS_FAIL', 'possible break-in' in text or 'reverse mapping' in text),
        ('DISCONNECT', 'disconnect' in text),
        ('CONNECTION_CLOSED', 'connection closed' in text),
        ('SESSION_OPENED', 'session opened' in text),
        ('SESSION_CLOSED', 'session closed' in text),
    )
    for label, matched in rules:
        if matched:
            return label
    return 'OTHER'

def infer_year(lines) -> int:
    """Guess the year for log timestamps that do not carry one.

    For each of the first 50 lines, the first three whitespace-separated
    tokens plus the current year are parsed as a timestamp. If any of them
    lies in the future, the log is assumed to be from last year.

    Args:
        lines: Sliceable sequence of raw log lines.

    Returns:
        The current year, or the current year minus one if a sampled
        timestamp would otherwise be in the future.
    """
    # Read the clock once so the year used for parsing and the "future"
    # comparison come from the same instant.
    now = datetime.now()
    current_year = now.year

    for line in lines[:50]:
        parts = line.split()
        if len(parts) < 3:
            continue
        try:
            ts = dateparser.parse(f"{parts[0]} {parts[1]} {parts[2]} {current_year}")
            in_future = ts > now
        except (ValueError, OverflowError, TypeError):
            # ValueError/OverflowError: the tokens are not a parseable date
            # (e.g. "Feb 29" in a non-leap year). TypeError: the parsed value
            # cannot be compared with the naive `now`. Skip the line either way.
            continue
        if in_future:
            return current_year - 1
    return current_year

# === PARSE LOG FILE ===
# Turns each usable log line into a (timestamp, src_ip, event_type) record.
# A line is usable when it has at least 4 whitespace-separated tokens, contains
# an IP that extract_ip() recognises, and its first three tokens parse as a date.
print("üîç Parsing log file...")
with open(log_path, 'r', encoding='utf-8', errors='ignore') as f:
    raw_lines = f.readlines()

# NOTE(review): every event is stamped with this single year, so a log that
# crosses New Year will have its January lines ordered before December.
year = infer_year(raw_lines)
print(f"üìÖ Using year: {year}")

events = []
for line in raw_lines:
    line = line.strip()
    if not line:
        continue
    parts = line.split()
    if len(parts) < 4:
        continue
    # Cheap regex filter first: lines without an IP are dropped anyway, so
    # there is no point paying for date parsing on them.
    ip = extract_ip(line)
    if not ip:
        continue
    try:
        timestamp = dateparser.parse(f"{parts[0]} {parts[1]} {parts[2]} {year}")
    except (ValueError, OverflowError):
        # Not a parseable timestamp prefix: skip the line.
        continue
    events.append({'timestamp': timestamp, 'src_ip': ip, 'event_type': classify_event(line)})

# An empty frame has no 'src_ip'/'timestamp' columns, so sort_values() would
# fail with a confusing KeyError; fail early with an actionable message instead.
if not events:
    raise ValueError(
        f"No events with a parseable timestamp and source IP were found in {log_path!r}."
    )

# Sorted by IP then time so each IP's events are contiguous and chronological.
df_events = pd.DataFrame(events).sort_values(['src_ip', 'timestamp']).reset_index(drop=True)
print(f"‚úÖ Parsed {len(df_events):,} events from {df_events['src_ip'].nunique():,} IPs")
print(f"\nüìà Event distribution:\n{df_events['event_type'].value_counts()}")

# === BUILD SEQUENCES & LABELS ===
# Slides a WINDOW_SIZE-event window (step STRIDE) over each IP's events and
# emits one token-id sequence plus one weak label per window.
print(f"\nüîÑ Building sequences (WINDOW={WINDOW_SIZE}, STRIDE={STRIDE})...")

sequences, labels = [], []
for _ip, group in df_events.groupby('src_ip'):
    n = len(group)
    if n < WINDOW_SIZE:
        continue

    # Convert the per-IP columns to NumPy arrays once; slicing arrays per
    # window is far cheaper than slicing the DataFrame with .iloc each time.
    token_ids = np.array([token2id[e] for e in group['event_type']])
    timestamps = group['timestamp'].to_numpy()
    suspicious = group['event_type'].isin(SUSPICIOUS_EVENTS).to_numpy()

    for start in range(0, n - WINDOW_SIZE + 1, STRIDE):
        end = start + WINDOW_SIZE
        window_ts = timestamps[start:end]

        # Weak labeling: a window is an attack when it is both dense in time
        # and contains enough suspicious events.
        time_span = (window_ts.max() - window_ts.min()) / np.timedelta64(1, 's')
        suspicious_count = suspicious[start:end].sum()
        label = 1 if (time_span <= TIME_WINDOW_SEC and suspicious_count >= FAIL_THRESHOLD) else 0

        sequences.append(token_ids[start:end])
        labels.append(label)

# With no windows the percentages below would be nan and training would fail
# later with an unrelated-looking error, so stop here with a clear message.
if not sequences:
    raise ValueError(
        f"No source IP has at least WINDOW_SIZE={WINDOW_SIZE} events, so no sequences "
        "could be built. Lower WINDOW_SIZE or provide a larger log."
    )

X = np.array(sequences)
y = np.array(labels)
n_attack = y.sum()
n_benign = (y == 0).sum()

print(f"\n‚úÖ Built {len(X):,} sequences")
print(f"   Attack: {n_attack:,} ({n_attack/len(y)*100:.1f}%)")
print(f"   Benign: {n_benign:,} ({n_benign/len(y)*100:.1f}%)")
if n_attack < 50:
    print(f"‚ö†Ô∏è Only {n_attack} attack sequences. Consider lowering FAIL_THRESHOLD to 6.")


In [None]:
# =============================================================================
# CELL 4: TRAIN LSTM
# =============================================================================

# Train/test split (stratified only when there are enough attack samples).
# NOTE(review): windows from the same IP overlap (STRIDE < WINDOW_SIZE with the
# configured values), so a random split can place near-duplicate windows on both
# sides and make test metrics optimistic. A per-IP grouped split would avoid
# this but needs the source IP of each sequence, which is not tracked here.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_SEED, stratify=y if n_attack >= 10 else None
)
print(f"Train: {len(X_train):,}, Test: {len(X_test):,}")

# Build model: token ids -> embedding -> BiLSTM -> dense head -> attack probability.
# The input shape is declared with an explicit Input layer rather than the
# deprecated Embedding(input_length=...) argument, so the model is built (and
# summary() shows real shapes) on both older and current Keras versions.
model = Sequential([
    tf.keras.Input(shape=(WINDOW_SIZE,)),
    Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM),
    Bidirectional(LSTM(LSTM_UNITS, return_sequences=False)),
    Dropout(DROPOUT),
    Dense(32, activation='relu'),
    Dropout(DROPOUT / 2),
    Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print("\nüß† Model architecture:")
model.summary()

# Class weights for imbalance, derived from the training labels only so the
# held-out test rows do not influence training.
train_attack = int(y_train.sum())
train_benign = int(len(y_train) - train_attack)
class_weight = (
    {0: 1.0, 1: train_benign / train_attack}
    if train_attack > 0 and train_benign > 0
    else None
)
if class_weight:
    print(f"\n‚öñÔ∏è Class weights: {class_weight}")

# Train; the last 10% of X_train is held out by Keras for validation and
# EarlyStopping restores the weights of the best val_loss epoch.
print(f"\nüöÄ Training for up to {EPOCHS} epochs...")
history = model.fit(
    X_train, y_train,
    validation_split=0.1,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weight,
    callbacks=[EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)],
    verbose=1
)
print("\n‚úÖ Training complete!")


In [None]:
# =============================================================================
# CELL 5: EVALUATE
# =============================================================================

# Predict on test set (one attack probability per sequence).
y_prob = model.predict(X_test, verbose=0).flatten()

# Find the decision threshold with the best F1 over a coarse grid; 0.5 is kept
# when no candidate achieves an F1 above zero.
# NOTE(review): the threshold is tuned on the same test set the metrics below
# are computed on, so the reported numbers are optimistic.
best_threshold = 0.5
best_f1 = 0
for thresh in np.arange(0.1, 0.9, 0.05):
    y_pred_temp = (y_prob >= thresh).astype(int)
    _, _, f1_temp, _ = precision_recall_fscore_support(y_test, y_pred_temp, average='binary', zero_division=0)
    if f1_temp > best_f1:
        best_f1 = f1_temp
        best_threshold = thresh

FINAL_THRESHOLD = best_threshold
y_pred = (y_prob >= FINAL_THRESHOLD).astype(int)

# Metrics
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='binary', zero_division=0)
# Passing labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
# y_test/y_pred. Without it a 1x1 matrix had to be guessed at, and an
# all-attack test set was reported as true negatives.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

print("="*60)
print("üìä EVALUATION RESULTS")
print("="*60)
print(f"\nOptimal Threshold: {FINAL_THRESHOLD:.2f}")
print(f"\nMetrics:")
print(f"   Precision: {precision:.4f}")
print(f"   Recall:    {recall:.4f}")
print(f"   F1 Score:  {f1:.4f}")

print(f"\nConfusion Matrix:")
print(f"                  Predicted")
print(f"                  Benign  Attack")
print(f"Actual Benign     {tn:6}  {fp:6}")
print(f"Actual Attack     {fn:6}  {tp:6}")
print("="*60)


In [None]:
# =============================================================================
# CELL 6: EXPORT ssh_lstm.joblib
# =============================================================================
# Re-imported so this cell can run on its own; google.colab is only available
# inside a Colab runtime.
from google.colab import files

# Bundle everything needed to reload: the architecture as JSON, the weight
# arrays, the token vocabulary, the windowing/labeling settings and the
# decision threshold chosen during evaluation.
export_bundle = {
    'model_json': model.to_json(),
    'weights': model.get_weights(),
    'token2id': token2id,
    'window_size': WINDOW_SIZE,
    'stride': STRIDE,
    'fail_threshold': FAIL_THRESHOLD,
    'time_window_sec': TIME_WINDOW_SEC,
    'threshold': float(FINAL_THRESHOLD)  # cast to a plain Python float before saving
}

# NOTE(review): joblib persistence is pickle-based as far as I know, so the
# resulting file should only be loaded from a trusted source - confirm.
output_file = 'ssh_lstm.joblib'
joblib.dump(export_bundle, output_file)
print(f"‚úÖ Saved: {output_file}")

# Download the saved bundle through the Colab files helper.
files.download(output_file)
print("\nüì• Download started!")


In [None]:
# =============================================================================
# CELL 7: SUMMARY
# =============================================================================
# Assemble the full report first, then print it one line at a time. Reads the
# results left behind by the parsing, sequence-building, evaluation and export cells.
divider = "="*60
summary_lines = [
    divider,
    "üìã PIPELINE SUMMARY",
    divider,
    f"\nüìä Data:",
    f"   Parsed events:     {len(df_events):,}",
    f"   Unique IPs:        {df_events['src_ip'].nunique():,}",
    f"   Total sequences:   {len(X):,}",
    f"\nüè∑Ô∏è Labels:",
    f"   Attack (1):  {n_attack:,} ({n_attack/len(y)*100:.1f}%)",
    f"   Benign (0):  {n_benign:,} ({n_benign/len(y)*100:.1f}%)",
    f"\nüìà Model Performance:",
    f"   Precision: {precision:.4f}",
    f"   Recall:    {recall:.4f}",
    f"   F1 Score:  {f1:.4f}",
    f"   Threshold: {FINAL_THRESHOLD:.2f}",
    f"\nüì¶ Exported:",
    f"   File: {output_file}",
    divider,
    "\n‚úÖ Pipeline complete!",
]
for summary_line in summary_lines:
    print(summary_line)
