### installations

In [3]:
### 1. Install Dependencies
!pip install torch charformer-pytorch

Collecting charformer-pytorch
  Downloading charformer_pytorch-0.0.4-py3-none-any.whl.metadata (655 bytes)
Downloading charformer_pytorch-0.0.4-py3-none-any.whl (4.8 kB)
Installing collected packages: charformer-pytorch
Successfully installed charformer-pytorch-0.0.4


In [42]:
!pip install libinjection-python

Collecting libinjection-python
  Downloading libinjection-python-1.1.6.tar.gz (174 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/174.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.0/174.0 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: libinjection-python
  Building wheel for libinjection-python (pyproject.toml) ... [?25l[?25hdone
  Created wheel for libinjection-python: filename=libinjection_python-1.1.6-cp312-cp312-linux_x86_64.whl size=251706 sha256=b4ebc3bba7a9dc748bb1b8bc92b5ad841afca5e57e7b24ddeb693b3ce48d39a7
  Stored in directory: /root/.cache/pip/wheels/b5/a2/0f/eb48da355b19a32635f793215e5d5908b072f7dc951e9fe295
Successfully built libinjection-python
Installing collected packa

### injection model detector

In [11]:
import torch
import torch.nn as nn
from typing import Dict, List, Optional
import warnings
from charformer_pytorch import GBST
warnings.filterwarnings('ignore')

In [6]:
class InjectionDetectionModel(nn.Module):
    """
    Enhanced model with better regularization
    """
    def __init__(self,
                 num_tokens: int = 257,
                 dim: int = 128,
                 max_block_size: int = 4,
                 score_consensus_attn: bool = True,
                 d_model: int = 128,
                 nhead: int = 1,
                 dim_feedforward: int = 256,
                 num_layers: int = 1,
                 max_length: int = 2048,
                 downsample_factor: int = 4,
                 mlp_hidden_dims: List[int] = [256, 128],
                 dropout: float = 0.2,  # INCREASED from 0.1 to 0.2
                 attack_type: str = "unknown"):

        super(InjectionDetectionModel, self).__init__()

        self.attack_type = attack_type
        self.max_length = max_length
        self.downsample_factor = downsample_factor
        self.d_model = d_model
        self.gbst_norm = nn.LayerNorm(d_model)

        # GBST architecture (unchanged)
        self.gbst = GBST(
            num_tokens=num_tokens,
            dim=dim,
            max_block_size=max_block_size,
            score_consensus_attn=score_consensus_attn
        )

        # Transformer with increased dropout
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,  # Now uses the increased dropout
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Enhanced MLP with Batch Normalization
        mlp_layers = []
        input_dim = d_model

        for hidden_dim in mlp_hidden_dims:
            mlp_layers.extend([
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),  # ADDED: Batch normalization
                nn.Dropout(dropout)
            ])
            input_dim = hidden_dim

        mlp_layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*mlp_layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Forward pass - unchanged logic"""
        batch_size = x.size(0)

        # GBST Tokenization
        gbst_output = self.gbst(x)
        if isinstance(gbst_output, tuple):
            gbst_output, gbst_mask = gbst_output

        gbst_output = self.gbst_norm(gbst_output)

        # Transformer Encoder
        if attention_mask is not None:
            attention_mask = attention_mask[:, ::self.downsample_factor]

        encoder_output = self.transformer_encoder(
            gbst_output,
            src_key_padding_mask=attention_mask if attention_mask is not None else None
        )

        # Classification with attention masking
        if attention_mask is not None:
            mask_expanded = (~attention_mask).unsqueeze(-1).expand_as(encoder_output)
            encoder_output = encoder_output * mask_expanded.float()
            pooled = encoder_output.sum(dim=1) / mask_expanded.sum(dim=1)
        else:
            pooled = encoder_output.mean(dim=1)

        logits = self.mlp(pooled)
        probabilities = self.sigmoid(logits)

        return {
            'logits': logits.squeeze(-1),
            'probabilities': probabilities.squeeze(-1),
            'encoder_output': encoder_output,
            'attention_weights': None
        }


In [7]:
class MultiClassInjectionModel(nn.Module):
    """
    Transfer learning model for multi-class attack classification
    Replaces binary classification head with multi-class head
    """

    def __init__(self, base_model: nn.Module, num_classes: int = 4, hidden_dim: int = 128):
        super(MultiClassInjectionModel, self).__init__()

        # Keep original model components
        self.gbst = base_model.gbst
        self.gbst_norm = base_model.gbst_norm
        self.transformer_encoder = base_model.transformer_encoder
        self.d_model = base_model.d_model
        self.downsample_factor = base_model.downsample_factor

        # Replace classification head with multi-class head
        self.multi_class_head = nn.Sequential(
            nn.Linear(self.d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )

        print(f"Multi-class model created with {num_classes} classes")

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        batch_size = x.size(0)

        # GBST encoding
        gbst_output = self.gbst(x)
        if isinstance(gbst_output, tuple):
            gbst_output, _ = gbst_output
        gbst_output = self.gbst_norm(gbst_output)

        # Transformer encoding
        if attention_mask is not None:
            attention_mask = attention_mask[:, ::self.downsample_factor]

        encoder_output = self.transformer_encoder(
            gbst_output,
            src_key_padding_mask=attention_mask if attention_mask is not None else None
        )

        # Pooling
        if attention_mask is not None:
            mask_expanded = (~attention_mask).unsqueeze(-1).expand_as(encoder_output)
            encoder_output = encoder_output * mask_expanded.float()
            pooled = encoder_output.sum(dim=1) / mask_expanded.sum(dim=1)
        else:
            pooled = encoder_output.mean(dim=1)

        # Multi-class classification
        logits = self.multi_class_head(pooled)
        probabilities = torch.softmax(logits, dim=-1)

        return {
            'logits': logits,
            'probabilities': probabilities,
            'encoder_output': encoder_output
        }


In [8]:
"""
Production Injection Detection System
Provides both binary detection and attack type classification
"""


class InjectionDetector:
    """
    Production-ready injection detection system

    Usage:
        detector = InjectionDetector(
            binary_model_path='models/binary_model.pth',
            multiclass_model_path='models/multiclass_model.pth'
        )

        result = detector.detect('SELECT * FROM users')
    """

    def __init__(self,
                 binary_model_path: str,
                 multiclass_model_path: str,
                 device: str = None,
                 max_length: int = 2048):
        """
        Initialize detector with trained models

        Args:
            binary_model_path: Path to binary classification model (.pth)
            multiclass_model_path: Path to multi-class model (.pth)
            device: 'cuda' or 'cpu'. Auto-detects if None
            max_length: Maximum input length (default: 2048)
        """

        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.max_length = max_length

        # Load binary model
        print(f"Loading binary model from {binary_model_path}...")
        #from models import InjectionDetectionModel  # Your model definition
        self.binary_model = InjectionDetectionModel()
        self.binary_model.load_state_dict(torch.load(binary_model_path, map_location=self.device))
        self.binary_model.to(self.device)
        self.binary_model.eval()

        # Load multi-class model
        print(f"Loading multi-class model from {multiclass_model_path}...")
        #from models import MultiClassInjectionModel
        base_model = InjectionDetectionModel()
        self.multiclass_model = MultiClassInjectionModel(base_model, num_classes=4)
        self.multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location=self.device))
        self.multiclass_model.to(self.device)
        self.multiclass_model.eval()

        self.attack_types = ['sqli', 'commandi', 'xss', 'traversal']

        print(f"Models loaded successfully on {self.device}")

    def _preprocess(self, text: str) -> Dict[str, torch.Tensor]:
        """Convert text to model input format"""

        # Convert to bytes
        byte_sequence = list(text.encode('utf-8'))

        if len(byte_sequence) > self.max_length:
            byte_sequence = byte_sequence[:self.max_length]

        attention_mask = [False] * len(byte_sequence)

        while len(byte_sequence) < self.max_length:
            byte_sequence.append(0)
            attention_mask.append(True)

        return {
            'input_ids': torch.tensor([byte_sequence], dtype=torch.long).to(self.device),
            'attention_mask': torch.tensor([attention_mask], dtype=torch.bool).to(self.device)
        }

    def detect(self, text: str, threshold: float = 0.5) -> Dict:
        """
        Detect if input contains injection attack

        Args:
            text: Input text to analyze
            threshold: Detection threshold (default: 0.5)

        Returns:
            Dict with detection results:
            {
                'is_malicious': bool,
                'confidence': float,
                'attack_type': str or None,
                'attack_confidence': float or None
            }
        """

        if not text or not text.strip():
            return {
                'is_malicious': False,
                'confidence': 1.0,
                'attack_type': None,
                'attack_confidence': None
            }

        # Preprocess
        inputs = self._preprocess(text)

        with torch.no_grad():
            # Binary classification
            binary_output = self.binary_model(inputs['input_ids'], inputs['attention_mask'])
            malicious_prob = binary_output['probabilities'].item()
            is_malicious = malicious_prob > threshold

            # If malicious, classify attack type
            attack_type = None
            attack_confidence = None

            if is_malicious:
                multiclass_output = self.multiclass_model(inputs['input_ids'], inputs['attention_mask'])
                probs = multiclass_output['probabilities'][0].cpu().numpy()
                attack_idx = probs.argmax()
                attack_type = self.attack_types[attack_idx]
                attack_confidence = float(probs[attack_idx])

        return {
            'is_malicious': is_malicious,
            'confidence': float(malicious_prob),
            'attack_type': attack_type,
            'attack_confidence': attack_confidence
        }

    def batch_detect(self, texts: List[str], threshold: float = 0.5) -> List[Dict]:
        """
        Batch detection for multiple texts

        Args:
            texts: List of texts to analyze
            threshold: Detection threshold

        Returns:
            List of detection results
        """
        return [self.detect(text, threshold) for text in texts]


In [56]:
# Initialize once (expensive operation)
detector = InjectionDetector(
    binary_model_path='/content/best_binary_injection_model.pth',
    multiclass_model_path='/content/best_multiclass_model.pth',
    device='cpu'  # 'cuda' or 'cpu'
)
start_time = time.time()
# Single prediction
text = "value=%27test%27"
result = detector.detect(text, threshold=0.5)
end_time = time.time()
avg_time = (end_time - start_time) / 1000 * 1000  # Convert to milliseconds
print(f"Average processing time: {avg_time:.3f} ms per request")
# Result format:
# {
#     'is_malicious': True/False,
#     'confidence': 0.0-1.0,
#     'attack_type': 'sqli'/'xss'/'commandi'/'traversal' or None,
#     'attack_confidence': 0.0-1.0 or None
# }

# Batch prediction
texts = ["text1", "text2", "text3"]
results = detector.batch_detect(texts)

Loading binary model from /content/best_binary_injection_model.pth...
Loading multi-class model from /content/best_multiclass_model.pth...
Multi-class model created with 4 classes
Models loaded successfully on cpu
Average processing time: 0.069 ms per request


In [57]:
result

{'is_malicious': True,
 'confidence': 0.7882174253463745,
 'attack_type': 'xss',
 'attack_confidence': 0.9999486207962036}

In [49]:
results

[{'is_malicious': False,
  'confidence': 0.12550747394561768,
  'attack_type': None,
  'attack_confidence': None},
 {'is_malicious': False,
  'confidence': 0.13454963266849518,
  'attack_type': None,
  'attack_confidence': None},
 {'is_malicious': False,
  'confidence': 0.13089822232723236,
  'attack_type': None,
  'attack_confidence': None}]

### Filtering pipeline

In [44]:
"""
Rule-based injection detection filter using libinjection and custom patterns
Returns: 0 (benign), 1 (attack), 2 (ambiguous - needs model)
"""

import re
import string
from typing import Union

# Install required library: pip install libinjection-python
try:
    import libinjection
    LIBINJECTION_AVAILABLE = True
except ImportError:
    print("Warning: libinjection-python not installed. Install with: pip install libinjection-python")
    LIBINJECTION_AVAILABLE = False

class RuleBasedInjectionFilter:
    """
    Fast rule-based filter for web injection detection
    Uses libinjection for definitive attack detection and custom patterns for benign detection
    """

    def __init__(self):
        # Compile regex patterns for performance
        self._compile_patterns()

    def _compile_patterns(self):
        """Compile all regex patterns for better performance"""

        # Obviously benign patterns (return 0)
        self.benign_patterns = [
            # Pure alphanumeric
            re.compile(r'^[a-zA-Z0-9_\-\.@\s]*$'),

            # Common usernames/emails
            re.compile(r'^[a-zA-Z0-9][a-zA-Z0-9_\-\.]{1,63}@[a-zA-Z0-9][a-zA-Z0-9\-\.]+\.[a-zA-Z]{2,}$'),

            # Simple numeric values
            re.compile(r'^\d+(\.\d+)?$'),

            # Common search terms (letters, numbers, spaces, basic punctuation)
            re.compile(r'^[a-zA-Z0-9\s\'\"\,\.\!\?\-]+$'),

            # URL-safe strings
            re.compile(r'^[a-zA-Z0-9\-\._~:/?#[\]@!$&\'()*+,;=]*$'),
        ]

        # Obviously malicious patterns (return 1) - High confidence
        self.malicious_patterns = [
            # SQL injection keywords (high confidence)
            re.compile(r'\b(union\s+select|drop\s+table|delete\s+from|insert\s+into)\b', re.IGNORECASE),

            # XSS script tags
            re.compile(r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<iframe[^>]*>.*?</iframe>', re.IGNORECASE | re.DOTALL),

            # JavaScript execution
            re.compile(r'javascript\s*:', re.IGNORECASE),
            re.compile(r'on(load|error|click|mouseover)\s*=', re.IGNORECASE),

            # Command injection
            re.compile(r'\b(cmd|system|exec|eval|passthru)\s*\(', re.IGNORECASE),

            # Path traversal
            re.compile(r'\.\.[\\/]\.\.[\\/]'),

            # SQL comment patterns
            re.compile(r'--\s*$', re.MULTILINE),
            re.compile(r'/\*.*?\*/', re.DOTALL),
        ]

        # Suspicious patterns that need model evaluation (return 2)
        self.suspicious_patterns = [
            # SQL-like patterns (lower confidence)
            re.compile(r'\b(select|from|where|order\s+by|group\s+by)\b', re.IGNORECASE),

            # HTML-like patterns
            re.compile(r'<[^>]+>'),

            # Encoding patterns
            re.compile(r'%[0-9a-f]{2}', re.IGNORECASE),
            re.compile(r'&#\d+;'),

            # Function call patterns
            re.compile(r'\w+\s*\([^)]*\)'),

            # Suspicious characters clustering
            re.compile(r'[<>"\'\(\);=%&\|]{3,}'),
        ]

    def is_obviously_benign(self, text: str) -> bool:
        """Check if text matches obviously benign patterns"""

        # Empty or very short strings are usually benign
        if not text or len(text.strip()) <= 2:
            return True

        # Check basic characteristics
        if self._is_simple_alphanumeric(text):
            return True

        # Check against benign patterns
        for pattern in self.benign_patterns:
            if pattern.match(text.strip()):
                return True

        return False

    def _is_simple_alphanumeric(self, text: str) -> bool:
        """Check if text is simple alphanumeric with minimal special chars"""

        # Count character types
        alpha_count = sum(1 for c in text if c.isalpha())
        digit_count = sum(1 for c in text if c.isdigit())
        space_count = sum(1 for c in text if c.isspace())
        safe_special = sum(1 for c in text if c in '_-.@')
        other_count = len(text) - alpha_count - digit_count - space_count - safe_special

        # If mostly alphanumeric with minimal special characters
        total_safe = alpha_count + digit_count + space_count + safe_special
        return len(text) > 0 and (total_safe / len(text)) >= 0.95

    def has_definitive_attack(self, text: str) -> bool:
        """Check for definitive attack patterns using libinjection and high-confidence patterns"""

        # Use libinjection for SQL injection detection
        if LIBINJECTION_AVAILABLE:
            try:
                # Check SQL injection
                sqli_result = libinjection.is_sql_injection(text)
                if sqli_result.get('is_sqli', False):
                    return True

                # Check XSS
                xss_result = libinjection.is_xss(text)
                if xss_result.get('is_xss', False):
                    return True
            except:
                pass  # Fall back to pattern matching if libinjection fails

        # Check high-confidence malicious patterns
        for pattern in self.malicious_patterns:
            if pattern.search(text):
                return True

        return False

    def has_suspicious_patterns(self, text: str) -> bool:
        """Check for suspicious patterns that need model evaluation"""

        for pattern in self.suspicious_patterns:
            if pattern.search(text):
                return True

        return False

    def classify_text(self, text: str) -> int:
        """
        Main classification function

        Returns:
            0: Obviously benign (safe to pass)
            1: Obviously malicious (block immediately)
            2: Ambiguous (needs model evaluation)
        """

        if not isinstance(text, str):
            return 2  # Non-string input needs evaluation

        # Basic preprocessing
        text = text.strip()

        # Empty strings are benign
        if not text:
            return 0

        # Check for obvious attacks first (highest priority)
        if self.has_definitive_attack(text):
            return 1

        # Check for obviously benign patterns
        if self.is_obviously_benign(text):
            return 0

        # Check for suspicious patterns
        if self.has_suspicious_patterns(text):
            return 2

        # If no patterns match but not obviously benign, be cautious
        # Check text characteristics
        if self._needs_model_evaluation(text):
            return 2

        # Default to benign for simple text
        return 0

    def _needs_model_evaluation(self, text: str) -> bool:
        """Determine if text characteristics suggest model evaluation is needed"""

        # Very long strings might hide attacks
        if len(text) > 1000:
            return True

        # High ratio of special characters
        special_chars = sum(1 for c in text if not c.isalnum() and not c.isspace())
        if len(text) > 0 and (special_chars / len(text)) > 0.3:
            return True

        # Contains multiple encoding types
        has_url_encoding = '%' in text
        has_html_encoding = '&' in text and ';' in text
        has_unicode = any(ord(c) > 127 for c in text)

        encoding_count = sum([has_url_encoding, has_html_encoding, has_unicode])
        if encoding_count >= 2:
            return True

        # Contains mixed quotes and brackets (potential injection)
        quote_bracket_chars = sum(1 for c in text if c in '\'"()[]{}')
        if quote_bracket_chars >= 4:
            return True

        return False

# Convenience function for direct use
def quick_injection_check(text: str) -> int:
    """
    Quick injection detection function

    Args:
        text: Input text to check

    Returns:
        0: Obviously benign (safe to pass)
        1: Obviously malicious (block immediately)
        2: Ambiguous (needs model evaluation)
    """
    filter_instance = RuleBasedInjectionFilter()
    return filter_instance.classify_text(text)



In [45]:
# Example usage and testing
if __name__ == "__main__":

    # Initialize filter
    injection_filter = RuleBasedInjectionFilter()

    # Test cases
    test_cases = [
        # Obviously benign (should return 0)
        ("john_smith", 0),
        ("user@example.com", 0),
        ("hello world tout le monde! j'espere que vous allez bien ? (bien)", 0),
        ("12345", 0),
        ("search term here", 0),

        # Obviously malicious (should return 1)
        ("'; DROP TABLE users; --", 1),
        ("UNION SELECT password FROM users", 1),
        ("<script>alert('xss')</script>", 1),
        ("javascript:alert(1)", 1),
        ("../../etc/passwd", 1),

        # Ambiguous (should return 2)
        ("SELECT name FROM products WHERE id=1", 2),
        ("<div>content</div>", 2),
        ("value=%27test%27", 2),
        ("function(param)", 2),
    ]

    print("Testing Rule-Based Injection Filter")
    print("="*50)

    for text, expected in test_cases:
        result = injection_filter.classify_text(text)
        status = "✓" if result == expected else "✗"

        result_map = {0: "BENIGN", 1: "ATTACK", 2: "AMBIGUOUS"}
        print(f"{status} '{text}' -> {result} ({result_map[result]})")

    print("\nLibinjection available:", LIBINJECTION_AVAILABLE)

    # Performance test
    if LIBINJECTION_AVAILABLE:
        import time

        test_text = "normal user input text here"
        start_time = time.time()

        for _ in range(1000):
            injection_filter.classify_text(test_text)

        end_time = time.time()
        avg_time = (end_time - start_time) / 1000 * 1000  # Convert to milliseconds

        print(f"\nAverage processing time: {avg_time:.3f} ms per request")

Testing Rule-Based Injection Filter
✓ 'john_smith' -> 0 (BENIGN)
✓ 'user@example.com' -> 0 (BENIGN)
✓ 'hello world tout le monde! j'espere que vous allez bien ? (bien)' -> 0 (BENIGN)
✓ '12345' -> 0 (BENIGN)
✓ 'search term here' -> 0 (BENIGN)
✓ ''; DROP TABLE users; --' -> 1 (ATTACK)
✓ 'UNION SELECT password FROM users' -> 1 (ATTACK)
✓ '<script>alert('xss')</script>' -> 1 (ATTACK)
✓ 'javascript:alert(1)' -> 1 (ATTACK)
✓ '../../etc/passwd' -> 1 (ATTACK)
✗ 'SELECT name FROM products WHERE id=1' -> 1 (ATTACK)
✓ '<div>content</div>' -> 2 (AMBIGUOUS)
✓ 'value=%27test%27' -> 2 (AMBIGUOUS)
✗ 'function(param)' -> 0 (BENIGN)

Libinjection available: True

Average processing time: 0.017 ms per request
