# 🔐 Crypto-BERT Model Training (Colab Optimized)
## Wi-Fi Vulnerability Detection System - Protocol Analysis Module

**Optimized for Google Colab Free Tier**
- Reduced model size for memory efficiency
- Faster training with smaller dataset
- Mixed precision training for speed
- Memory-efficient data loading

## 📦 Installation and Setup

In [None]:
# Install required packages
!pip install -q tensorflow==2.13.0
!pip install -q scikit-learn
!pip install -q matplotlib seaborn
!pip install -q tqdm

# Setup memory growth and mixed precision
import tensorflow as tf

# Enable memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("🚀 GPU memory growth enabled")
    except RuntimeError as e:
        print(e)

# Enable mixed precision for faster training
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
print(f"🎯 Mixed precision policy: {policy.name}")

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {len(tf.config.list_physical_devices('GPU')) > 0}")

In [None]:
# Import required libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.utils import to_categorical

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from tqdm import tqdm
import json
import random
import re
import warnings
import gc
warnings.filterwarnings('ignore')

# Seed NumPy, TensorFlow and Python's `random` with the same value so data
# generation, the train/val/test split and weight initialisation are repeatable.
SEED = 42
for _seed_rng in (np.random.seed, tf.random.set_seed, random.seed):
    _seed_rng(SEED)

print("✅ All imports successful!")

## ⚙️ Optimized Model Configuration

In [None]:
# Optimized BERT Configuration for Colab
class CryptoBERTConfig:
    """Hyperparameters for the reduced-size Crypto-BERT classifier.

    Every value is a plain instance attribute set in ``__init__``; the
    constructor takes no arguments. ``class_labels`` lists the vulnerability
    categories in label-index order and ``num_classes`` equals its length.
    """

    def __init__(self):
        # Vulnerability categories; a label's list position is its class index.
        labels = [
            'STRONG_ENCRYPTION',
            'WEAK_CIPHER_SUITE',
            'CERTIFICATE_INVALID',
            'KEY_REUSE',
            'DOWNGRADE_ATTACK',
            'MAN_IN_MIDDLE',
            'REPLAY_ATTACK',
            'TIMING_ATTACK',
            'QUANTUM_VULNERABLE',
            'ENTROPY_WEAKNESS',
            'HASH_COLLISION',
            'PADDING_ORACLE',
            'LENGTH_EXTENSION',
            'PROTOCOL_CONFUSION',
            'CRYPTO_AGILITY_LACK',
        ]

        # --- Model architecture (shrunk to fit Colab memory) ---
        self.vocab_size = 15000            # reduced from 30000
        self.hidden_size = 384             # reduced from 768
        self.max_sequence_length = 256     # reduced from 512
        self.num_transformer_layers = 6    # reduced from 12
        self.num_attention_heads = 6       # reduced from 12
        self.intermediate_size = 4 * self.hidden_size  # feed-forward width (1536)
        self.dropout_rate = 0.1

        # --- Training parameters ---
        self.num_classes = len(labels)     # 15 categories
        self.batch_size = 32
        self.learning_rate = 3e-4
        self.epochs = 5
        self.samples_per_class = 800       # synthetic samples generated per label

        # --- Label names ---
        self.class_labels = labels

# Build the configuration object read by the cells below, then echo the key
# hyperparameters so the notebook output records what was actually used.
config = CryptoBERTConfig()
print(f"📋 Optimized Configuration:")
print(f"   Vocab Size: {config.vocab_size:,}")
print(f"   Hidden Size: {config.hidden_size}")
print(f"   Sequence Length: {config.max_sequence_length}")
print(f"   Transformer Layers: {config.num_transformer_layers}")
print(f"   Attention Heads: {config.num_attention_heads}")
print(f"   Batch Size: {config.batch_size}")
print(f"   Samples per Class: {config.samples_per_class}")
# Total dataset size = samples per class × number of classes.
print(f"   Total Samples: {config.samples_per_class * config.num_classes:,}")

## 🔄 Optimized Data Generation

In [None]:
class OptimizedProtocolGenerator:
    """Template-based generator of synthetic protocol-analysis text samples.

    Each class label maps to three fixed template sentences. A sample is one
    randomly chosen template, optionally followed by a random
    ``"<field> <4-digit number>"`` suffix.

    Args:
        config: Object exposing ``class_labels`` (iterable of label names that
            are keys of ``templates``) and ``samples_per_class`` (int).
    """

    def __init__(self, config):
        self.config = config

        # Label name -> candidate template sentences for that label.
        self.templates = {
            'STRONG_ENCRYPTION': [
                "TLS1.3 AES-256-GCM ECDHE-P256 secure",
                "WPA3 ChaCha20-Poly1305 strong encryption",
                "TLS1.3 AES-128-GCM secure handshake",
            ],
            'WEAK_CIPHER_SUITE': [
                "TLS1.0 RC4 weak cipher detected",
                "WPA DES encryption vulnerable",
                "SSL3.0 MD5 deprecated hash",
            ],
            'CERTIFICATE_INVALID': [
                "TLS cert expired self-signed invalid",
                "SSL certificate chain broken untrusted",
                "HTTPS cert hostname mismatch",
            ],
            'KEY_REUSE': [
                "WPA2 nonce reuse detected session key",
                "TLS key exchange reused RSA",
                "EAP session key material reused",
            ],
            'DOWNGRADE_ATTACK': [
                "TLS1.3 downgrade to TLS1.0 detected",
                "WPA3 downgrade to WPA2 forced",
                "HTTPS downgrade to HTTP",
            ],
            'MAN_IN_MIDDLE': [
                "TLS cert fingerprint mismatch MITM",
                "DNS spoofing certificate invalid",
                "ARP spoofing SSL interception",
            ],
            'REPLAY_ATTACK': [
                "EAP message replay timestamp invalid",
                "TLS handshake replay detected",
                "WPA handshake message replayed",
            ],
            'TIMING_ATTACK': [
                "RSA decrypt timing variation high",
                "AES padding oracle timing leak",
                "HMAC verification timing attack",
            ],
            'QUANTUM_VULNERABLE': [
                "RSA-2048 quantum vulnerable algorithm",
                "ECDSA-P256 post quantum needed",
                "DH key exchange quantum weak",
            ],
            'ENTROPY_WEAKNESS': [
                "PRNG entropy low predictable random",
                "Random generation weak seed",
                "Nonce generation entropy insufficient",
            ],
            'HASH_COLLISION': [
                "MD5 hash collision detected same",
                "SHA1 collision vulnerability found",
                "Hash function collision attack",
            ],
            'PADDING_ORACLE': [
                "AES-CBC padding oracle error leak",
                "PKCS padding oracle attack possible",
                "RSA padding oracle vulnerability",
            ],
            'LENGTH_EXTENSION': [
                "SHA1 HMAC length extension possible",
                "Hash length extension attack detected",
                "MAC length extension vulnerability",
            ],
            'PROTOCOL_CONFUSION': [
                "TLS HTTP protocol confusion detected",
                "SSL strip protocol downgrade",
                "Mixed protocol implementation flaw",
            ],
            'CRYPTO_AGILITY_LACK': [
                "Legacy system crypto agility limited",
                "Single cipher suite support only",
                "No algorithm upgrade path available",
            ],
        }

        # Optional suffix patterns; "{}" is filled with a random 4-digit number.
        self.variations = [
            "session {}", "timestamp {}", "id {}", "port {}",
            "version {}", "length {}", "flags {}", "data {}",
        ]

    def generate_sequence(self, class_name):
        """Return one synthetic sample string for ``class_name``.

        A template is drawn uniformly from the class's list. With probability
        0.3 a random variation suffix is appended after a single space.

        Raises:
            KeyError: If ``class_name`` is not a key of ``self.templates``.
        """
        sample = random.choice(self.templates[class_name])

        # 70% of samples are the bare template.
        if random.random() >= 0.3:
            return sample

        # Draw the pattern first, then the number, to keep the RNG call order.
        pattern = random.choice(self.variations)
        suffix = pattern.format(random.randint(1000, 9999))
        return f"{sample} {suffix}"

    def generate_dataset(self):
        """Generate ``samples_per_class`` samples for every configured label.

        Returns:
            Tuple ``(sequences, labels)`` of equal-length lists: the sample
            strings and, for each, the index of its label in
            ``config.class_labels``. Samples are grouped by class, in label
            order.
        """
        all_sequences, all_labels = [], []

        print("🔄 Generating optimized dataset...")

        for label_index, label_name in enumerate(tqdm(self.config.class_labels)):
            count = self.config.samples_per_class
            all_sequences.extend(
                self.generate_sequence(label_name) for _ in range(count)
            )
            all_labels.extend([label_index] * count)

        return all_sequences, all_labels

# Generate optimized dataset
# `sequences` holds the raw text samples and `labels` the matching class
# indices (positions in config.class_labels); both are consumed by the
# tokenization cell.
generator = OptimizedProtocolGenerator(config)
sequences, labels = generator.generate_dataset()

print(f"✅ Generated {len(sequences):,} sequences")
print(f"📝 Sample: {sequences[0]}")

# Clear memory: the generator is not used again after this point.
del generator
gc.collect()

## 🔤 Optimized Tokenization

In [None]:
class SimpleTokenizer:
    """Whitespace tokenizer with a frequency-ranked, size-capped vocabulary.

    Ids 0-3 are reserved for ``[PAD]``, ``[UNK]``, ``[CLS]`` and ``[SEP]``.
    Text is lower-cased and split on whitespace before lookup.

    Args:
        vocab_size: Upper bound on the number of ids, special tokens included.
        max_length: Fixed length of every encoded row.
    """

    def __init__(self, vocab_size, max_length):
        self.vocab_size = vocab_size
        self.max_length = max_length
        # Special tokens take the first four ids, in this order.
        self.word_to_id = dict(zip(('[PAD]', '[UNK]', '[CLS]', '[SEP]'), range(4)))
        # Next id to hand out when a word is added.
        self.current_id = len(self.word_to_id)

    def build_vocab(self, sequences):
        """Add the most frequent words of ``sequences`` to the vocabulary.

        Words receive ids in descending frequency order (ties keep first-seen
        order) until ``vocab_size`` ids have been assigned.
        """
        print("🔤 Building vocabulary...")

        # Tally lower-cased, whitespace-separated tokens.
        counts = {}
        for text in tqdm(sequences, desc="Counting words"):
            for token in text.lower().split():
                if token in counts:
                    counts[token] += 1
                else:
                    counts[token] = 1

        # The sort is stable, so equally frequent words keep insertion order.
        ranked = sorted(counts, key=counts.get, reverse=True)

        for token in ranked:
            if self.current_id >= self.vocab_size:
                break
            self.word_to_id[token] = self.current_id
            self.current_id += 1

        print(f"✅ Vocabulary built: {len(self.word_to_id)} tokens")

    def encode_batch(self, sequences):
        """Encode texts as fixed-length rows of token ids.

        Each row is ``[CLS] tokens... [SEP]`` followed by ``[PAD]`` up to
        ``max_length``. Unknown words map to ``[UNK]``, and at most
        ``max_length - 2`` words are kept per text.

        Returns:
            ``np.ndarray`` of dtype int32 with one row per input text.
        """
        rows = []

        for text in tqdm(sequences, desc="Encoding"):
            vocab = self.word_to_id
            kept = text.lower().split()[:self.max_length-2]

            ids = [
                vocab['[CLS]'],
                *(vocab.get(token, vocab['[UNK]']) for token in kept),
                vocab['[SEP]'],
            ]

            # Right-pad up to the fixed length.
            shortfall = self.max_length - len(ids)
            if shortfall > 0:
                ids.extend([vocab['[PAD]']] * shortfall)

            rows.append(ids[:self.max_length])

        return np.array(rows, dtype=np.int32)

# Build tokenizer
# NOTE(review): the vocabulary is built from every generated sequence, i.e.
# before the train/val/test split performed in the next cell.
tokenizer = SimpleTokenizer(config.vocab_size, config.max_sequence_length)
tokenizer.build_vocab(sequences)

# Encode sequences
# X: int32 token-id matrix with one row per sequence; y: int32 class indices.
X = tokenizer.encode_batch(sequences)
y = np.array(labels, dtype=np.int32)

print(f"📝 Encoded shape: {X.shape}")
print(f"🏷️ Labels shape: {y.shape}")

# Clear intermediate data: only the encoded arrays are needed from here on.
del sequences, labels
gc.collect()

## 📊 Data Splitting

In [None]:
# Split data
# First hold out 20% of all samples as the test set, stratified by class.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SEED
)

# Then take 20% of the remaining 80% as validation, which gives
# 64% train / 16% validation / 20% test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, stratify=y_train, random_state=SEED
)

# Convert to categorical (one-hot) targets, as required by the
# 'categorical_crossentropy' loss used when compiling the model.
y_train_cat = to_categorical(y_train, config.num_classes)
y_val_cat = to_categorical(y_val, config.num_classes)
y_test_cat = to_categorical(y_test, config.num_classes)

print(f"📊 Data splits:")
print(f"   Train: {X_train.shape[0]:,}")
print(f"   Val: {X_val.shape[0]:,}")
print(f"   Test: {X_test.shape[0]:,}")

# Clear original data: the unsplit arrays are no longer needed.
del X, y
gc.collect()

## 🏗️ Lightweight BERT Model

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention, ``softmax(q·kᵀ / sqrt(d_k)) · v``.

    The logits and the softmax are evaluated in float32 regardless of the
    input dtype. Under the ``mixed_float16`` policy ``q``/``k``/``v`` arrive as
    float16; dividing float16 logits by a float32 scale is a dtype mismatch,
    and the ``-1e9`` mask offset is not representable in float16. The
    attention weights are cast back to ``v``'s dtype before the final matmul,
    so the output dtype matches ``v``.

    Args:
        q: Query tensor; its last dimension must match that of ``k``.
        k: Key tensor.
        v: Value tensor.
        mask: Optional tensor broadcastable to the logits. Positions where it
            is 1 receive a large negative offset and so get ~0 weight.

    Returns:
        The attention output tensor, in the dtype of ``v``.
    """
    # Raw similarity scores, promoted to float32 for a stable softmax.
    attention_logits = tf.cast(tf.matmul(q, k, transpose_b=True), tf.float32)

    # Scale by sqrt(d_k), where d_k is the size of the keys' last dimension.
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = attention_logits / tf.math.sqrt(dk)

    # Apply mask: push masked positions towards -inf before the softmax.
    if mask is not None:
        scaled_attention_logits += (tf.cast(mask, tf.float32) * -1e9)

    # Softmax over the key axis, then return to the compute dtype of `v`.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    attention_weights = tf.cast(attention_weights, v.dtype)
    output = tf.matmul(attention_weights, v)

    return output

class MultiHeadAttention(layers.Layer):
    """Multi-head self-attention over a single input sequence.

    The input is projected to queries, keys and values of width ``d_model``.
    Each projection is split into ``num_heads`` heads of width
    ``d_model // num_heads``, attended per head, then re-joined and passed
    through a final dense projection.

    Args:
        d_model: Width of the projections and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...), forwarded to the base class so the
            layer can be re-created from its config.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``
            (previously a bare ``assert``, which is skipped under ``-O``).
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.d_model = d_model

        # Width of each individual head.
        self.depth = d_model // self.num_heads

        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)

        self.dense = layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``x`` to (batch, seq, heads, depth), then move heads to axis 1."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs):
        """Apply self-attention to ``inputs`` and return a tensor of width ``d_model``."""
        batch_size = tf.shape(inputs)[0]

        q = self.wq(inputs)
        k = self.wk(inputs)
        v = self.wv(inputs)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention = scaled_dot_product_attention(q, k, v)

        # Undo the head split: back to (batch, seq, heads, depth) ...
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # ... then merge the last two axes into d_model.
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)
        return output

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and rebuilt."""
        config = super().get_config()
        config.update({
            'd_model': self.d_model,
            'num_heads': self.num_heads,
        })
        return config

def create_lightweight_bert(config):
    """Build the reduced-size BERT-style sequence classifier.

    Architecture: token embedding + learned position embedding, layer norm
    and dropout, then ``config.num_transformer_layers`` post-norm transformer
    blocks (self-attention and a GELU feed-forward network, each with a
    residual connection). The first position of the final sequence feeds a
    256-unit dense layer and a softmax over ``config.num_classes``.

    Args:
        config: Object providing ``max_sequence_length``, ``vocab_size``,
            ``hidden_size``, ``dropout_rate``, ``num_transformer_layers``,
            ``num_attention_heads``, ``intermediate_size`` and ``num_classes``.

    Returns:
        An uncompiled ``keras.Model`` named ``'LightweightCryptoBERT'`` that
        maps int32 token ids of shape (batch, max_sequence_length) to float32
        class probabilities of shape (batch, num_classes).
    """

    inputs = layers.Input(shape=(config.max_sequence_length,), dtype=tf.int32)

    # Token embedding (id 0 is treated as padding via mask_zero)
    embedding = layers.Embedding(
        config.vocab_size,
        config.hidden_size,
        mask_zero=True
    )(inputs)

    # Positional encoding.
    # The position ids are derived from the symbolic `inputs` so that the
    # position Embedding becomes a real layer of the model with trainable,
    # saved weights. Calling the Embedding on a plain `tf.range(...)` runs it
    # eagerly and bakes its random initial values into the graph as a constant.
    position_ids = tf.zeros_like(inputs) + tf.range(config.max_sequence_length)
    position_encoding = layers.Embedding(
        config.max_sequence_length,
        config.hidden_size
    )(position_ids)

    x = embedding + position_encoding
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Dropout(config.dropout_rate)(x)

    # Transformer blocks
    for i in range(config.num_transformer_layers):
        # Multi-head attention with residual connection + layer norm
        attn_output = MultiHeadAttention(config.hidden_size, config.num_attention_heads)(x)
        attn_output = layers.Dropout(config.dropout_rate)(attn_output)
        out1 = layers.LayerNormalization(epsilon=1e-6)(x + attn_output)

        # Feed forward with residual connection + layer norm
        ffn_output = layers.Dense(config.intermediate_size, activation='gelu')(out1)
        ffn_output = layers.Dense(config.hidden_size)(ffn_output)
        ffn_output = layers.Dropout(config.dropout_rate)(ffn_output)
        x = layers.LayerNormalization(epsilon=1e-6)(out1 + ffn_output)

    # Classification head: position 0 holds the [CLS] token the tokenizer
    # puts first in every sequence.
    cls_token = x[:, 0, :]
    cls_token = layers.Dropout(config.dropout_rate)(cls_token)

    # Output layers
    dense = layers.Dense(256, activation='gelu')(cls_token)
    dense = layers.Dropout(config.dropout_rate)(dense)

    # Final output with float32 for numerical stability
    outputs = layers.Dense(config.num_classes, activation='softmax', dtype='float32')(dense)

    model = Model(inputs, outputs, name='LightweightCryptoBERT')

    return model

# Create model
print("🏗️ Creating lightweight BERT model...")
model = create_lightweight_bert(config)

# Model summary
model.summary()

# Rough size estimate: assumes 4 bytes (float32) per parameter.
total_params = model.count_params()
model_size_mb = total_params * 4 / (1024 * 1024)

print(f"\n📊 Model Stats:")
print(f"   Parameters: {total_params:,}")
print(f"   Estimated Size: {model_size_mb:.1f} MB")
# NOTE(review): this figure only counts the token-id input of one batch
# (batch_size × sequence_length × 4 bytes); activations are not included.
print(f"   Memory per batch: ~{config.batch_size * config.max_sequence_length * 4 / (1024*1024):.1f} MB")

## 🎯 Optimized Training Setup

In [None]:
# Optimized training configuration
# AdamW with decoupled weight decay; the learning rate comes from the config.
optimizer = keras.optimizers.AdamW(
    learning_rate=config.learning_rate,
    weight_decay=0.01,
    beta_1=0.9,
    beta_2=0.999
)

# Compile the model. One-hot targets are paired with categorical cross-entropy.
# NOTE(review): no explicit loss scaling is configured here; presumably this
# relies on Keras wrapping the optimizer automatically under the
# mixed_float16 policy — confirm.
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Optimized callbacks
callbacks = [
    # Stop when validation accuracy has not improved for 2 epochs and roll
    # back to the best weights seen.
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=2,
        restore_best_weights=True,
        verbose=1
    ),
    # Halve the learning rate after 1 epoch without validation-loss
    # improvement, never going below 1e-6.
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=1,
        min_lr=1e-6,
        verbose=1
    )
]

print(f"🎯 Training config:")
print(f"   Learning Rate: {config.learning_rate}")
print(f"   Batch Size: {config.batch_size}")
print(f"   Epochs: {config.epochs}")
print(f"   Mixed Precision: {tf.keras.mixed_precision.global_policy().name}")

## 🚀 Fast Training

In [None]:
print("🚀 Starting optimized training...")
print(f"📊 Training samples: {X_train.shape[0]:,}")
print(f"🔍 Validation samples: {X_val.shape[0]:,}")

# Train on the in-memory NumPy arrays.
# `use_multiprocessing` / `workers` are intentionally not passed: they only
# apply to generator or keras.utils.Sequence inputs, have no effect on NumPy
# arrays, and are rejected by newer Keras releases.
history = model.fit(
    X_train, y_train_cat,
    batch_size=config.batch_size,
    epochs=config.epochs,
    validation_data=(X_val, y_val_cat),
    callbacks=callbacks,
    verbose=1
)

print("✅ Training completed!")

# Clear training data from memory; only the test split is needed afterwards.
del X_train, y_train_cat, X_val, y_val_cat
gc.collect()

## 📈 Quick Evaluation

In [None]:
# Plot training history: accuracy (left) and loss (right), one point per epoch.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training')
plt.plot(history.history['val_accuracy'], label='Validation')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training')
plt.plot(history.history['val_loss'], label='Validation')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Test evaluation on the held-out split
print("🧪 Testing...")
test_loss, test_acc = model.evaluate(X_test, y_test_cat, batch_size=config.batch_size, verbose=0)

print(f"\n📊 Final Results:")
print(f"   Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f"   Test Loss: {test_loss:.4f}")
print(f"   Best Val Accuracy: {max(history.history['val_accuracy']):.4f}")

# Quick predictions: argmax turns probabilities / one-hot rows into class indices.
y_pred = model.predict(X_test, batch_size=config.batch_size, verbose=0)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_test_cat, axis=1)

# NOTE(review): despite the "Top 5" wording, this reports the first five
# entries of config.class_labels. Each value is the fraction of that class's
# test samples predicted correctly, or 0 when the class has no test samples.
print(f"\n🎯 Classification Report (Top 5 classes):")
for i in range(5):
    class_acc = np.mean(y_pred_classes[y_true_classes == i] == i) if np.sum(y_true_classes == i) > 0 else 0
    print(f"   {config.class_labels[i][:20]}: {class_acc:.3f}")

# Memory cleanup
del X_test, y_test_cat, y_pred
gc.collect()

## 💾 Save Optimized Model

In [None]:
# Save model (the .h5 extension selects the HDF5 format)
model_path = 'crypto_bert_optimized.h5'
model.save(model_path)
print(f"💾 Model saved: {model_path}")

# Save tokenizer: the word→id mapping plus the two size settings are what is
# needed to encode new text the same way at inference time.
tokenizer_data = {
    'word_to_id': tokenizer.word_to_id,
    'vocab_size': tokenizer.vocab_size,
    'max_length': tokenizer.max_length
}

with open('tokenizer_optimized.json', 'w') as f:
    json.dump(tokenizer_data, f)
print("💾 Tokenizer saved: tokenizer_optimized.json")

# Save config: architecture settings, label names and headline metrics.
# Values are cast to built-in float/int so they are JSON-serialisable.
model_config = {
    'vocab_size': config.vocab_size,
    'hidden_size': config.hidden_size,
    'max_sequence_length': config.max_sequence_length,
    'num_transformer_layers': config.num_transformer_layers,
    'num_attention_heads': config.num_attention_heads,
    'num_classes': config.num_classes,
    'class_labels': config.class_labels,
    'test_accuracy': float(test_acc),
    'total_parameters': int(total_params),
    'model_size_mb': float(model_size_mb)
}

with open('config_optimized.json', 'w') as f:
    json.dump(model_config, f, indent=2)
print("💾 Config saved: config_optimized.json")

# Check file size of the saved model on disk, in MB
import os
actual_size = os.path.getsize(model_path) / (1024 * 1024)

print(f"\n📏 Final Model Stats:")
print(f"   File Size: {actual_size:.1f} MB")
print(f"   Parameters: {total_params:,}")
print(f"   Test Accuracy: {test_acc:.4f}")
print(f"   Training Time: Much faster! ⚡")

print(f"\n🎉 Optimized Crypto-BERT Ready!")
print(f"   ✅ Smaller model size")
print(f"   ✅ Faster training")
print(f"   ✅ Memory efficient")
print(f"   ✅ Colab compatible")

## 📥 Download Files

In [None]:
# Download in Colab
# If the google.colab import fails, the ImportError branch reports that the
# files were only saved locally.
try:
    from google.colab import files
    print("📥 Downloading optimized model files...")
    
    # The three artifacts written by the save cell.
    files.download('crypto_bert_optimized.h5')
    files.download('tokenizer_optimized.json')
    files.download('config_optimized.json')
    
    print("✅ Download complete!")
except ImportError:
    print("ℹ️ Files saved locally (not in Colab)")

print("\n🎊 All done! Your optimized Crypto-BERT is ready for integration with your other models.")

## 📝 Usage Notes

### Key Optimizations Made:

1. **Reduced Model Size**:
   - Hidden size: 768 → 384
   - Layers: 12 → 6
   - Sequence length: 512 → 256
   - Vocab size: 30K → 15K

2. **Memory Efficiency**:
   - Mixed precision training (float16)
   - Memory growth enabled
   - Batch processing
   - Garbage collection

3. **Training Speed**:
   - Larger batch size (32)
   - Fewer epochs (5)
   - Simplified data generation
   - Efficient tokenization

### Integration with Your System:

```python
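# Load the model. The custom MultiHeadAttention layer class from this notebook
# must be defined (or imported) first and passed via custom_objects, otherwise
# Keras cannot rebuild the layer.
model = tf.keras.models.load_model(
    'crypto_bert_optimized.h5',
    custom_objects={'MultiHeadAttention': MultiHeadAttention}
)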

# Load tokenizer
with open('tokenizer_optimized.json', 'r') as f:
    tokenizer_data = json.load(f)

# Use for predictions in your ensemble
predictions = model.predict(encoded_sequences)
```

This optimized version should train much faster on Colab while still providing good accuracy for your Wi-Fi vulnerability detection system!