# Markov Chain Analysis for Malicious Prompt Detection

This notebook implements a Markov chain-based approach to classify prompts as malicious or benign and to identify the word sequences that most strongly indicate a malicious prompt.

## Approach:
1. Extract k-grams (bigrams and trigrams) from prompts
2. Build separate Markov Chains for malicious and benign prompts
3. Calculate transition probabilities for word sequences
4. Use likelihood ratios to classify prompts
5. Identify most discriminative sequences
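
The steps above can be sketched end to end on a toy corpus. This is a minimal illustration rather than the notebook's implementation: the four example prompts and the helper names `bigram_counts` and `log_likelihood` are made up for this sketch, and add-one smoothing over a shared vocabulary is an assumption.

```python
from collections import Counter
import math

# Hypothetical toy corpus (not taken from MPDD.csv)
malicious = [["ignore", "previous", "instructions"], ["ignore", "all", "rules"]]
benign = [["summarize", "this", "article"], ["translate", "this", "sentence"]]

def bigram_counts(prompts):
    """Count adjacent word pairs and how often each word starts a pair."""
    pairs, starts = Counter(), Counter()
    for tokens in prompts:
        for a, b in zip(tokens, tokens[1:]):
            pairs[(a, b)] += 1
            starts[a] += 1
    return pairs, starts

# Shared vocabulary used as the smoothing denominator for both classes
vocab = {w for p in malicious + benign for w in p}

def log_likelihood(tokens, pairs, starts):
    """Sum of log P(next | current) with add-one smoothing."""
    total = 0.0
    for a, b in zip(tokens, tokens[1:]):
        total += math.log((pairs[(a, b)] + 1) / (starts[a] + len(vocab)))
    return total

mal_pairs, mal_starts = bigram_counts(malicious)
ben_pairs, ben_starts = bigram_counts(benign)

prompt = ["ignore", "previous", "rules"]
# A positive log-likelihood ratio means the malicious chain explains the prompt better
llr = (log_likelihood(prompt, mal_pairs, mal_starts)
       - log_likelihood(prompt, ben_pairs, ben_starts))
label = "malicious" if llr > 0 else "benign"
print(label, round(llr, 3))
```

The sign of `llr` gives the class, and its magnitude indicates how strongly the two chains disagree about the prompt.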

In [None]:
# Detect the runtime environment (no packages are installed here)
import sys

# Check if running on Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
else:
    print("Running locally")

## 1. Import Libraries and Load Dataset

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Set style for better visualizations
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

In [None]:
# Load dataset
if IN_COLAB:
    # Download from GitHub
    url = 'https://raw.githubusercontent.com/Meet2304/Project-Vigil/main/Dataset/MPDD.csv'
    df = pd.read_csv(url)
else:
    # Load from local path
    df = pd.read_csv('../Dataset/MPDD.csv')

print(f"Dataset shape: {df.shape}")
print(f"\nFirst few rows:")
df.head()

In [None]:
# Dataset statistics
print("Dataset Statistics:")
print(f"Total prompts: {len(df)}")
print(f"Malicious prompts: {df['isMalicious'].sum()} ({df['isMalicious'].sum()/len(df)*100:.1f}%)")
print(f"Benign prompts: {(1-df['isMalicious']).sum()} ({(1-df['isMalicious']).sum()/len(df)*100:.1f}%)")
print(f"\nMissing values:\n{df.isnull().sum()}")

## 2. Text Preprocessing
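
Before looking at the full function, it helps to see what this style of cleaning does to a single string. The sketch below applies the same regular expression used in `preprocess_text`; the helper name `tokenize` and the sample sentence are made up for illustration.

```python
import re

def tokenize(text):
    """Lowercase, replace non-alphanumeric characters with spaces, split on whitespace."""
    return re.sub(r'[^a-z0-9\s]', ' ', text.lower()).split()

tokens = tokenize("Don't IGNORE the rules!!  Visit example.com")
print(tokens)
```

Note that contractions and domain names are split at the punctuation, so "Don't" becomes the two tokens `don` and `t`. This affects which bigrams and trigrams the chains can learn.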

In [None]:
def preprocess_text(text):
    """
    Preprocess text for k-gram extraction:
    - Convert to lowercase
    - Remove special characters (keep only alphanumeric and spaces)
    - Remove extra whitespaces
    - Tokenize into words
    """
    if pd.isna(text):
        return []
    
    # Convert to lowercase
    text = str(text).lower()
    
    # Remove special characters but keep spaces
    text = re.sub(r'[^a-z0-9\s]', ' ', text)
    
    # Remove extra whitespaces
    text = ' '.join(text.split())
    
    # Tokenize
    tokens = text.split()
    
    return tokens

# Apply preprocessing
df['tokens'] = df['Prompt'].apply(preprocess_text)
df['token_count'] = df['tokens'].apply(len)

print("Preprocessing complete!")
print(f"Average tokens per prompt: {df['token_count'].mean():.1f}")
print(f"\nExample preprocessed prompt:")
print(f"Original: {df['Prompt'].iloc[0]}")
print(f"Tokens: {df['tokens'].iloc[0]}")

## 3. K-Gram Extraction

In [None]:
def extract_ngrams(tokens, n):
    """
    Extract n-grams from a list of tokens.
    
    Args:
        tokens: List of tokens
        n: Size of n-grams
    
    Returns:
        List of n-grams (tuples)
    """
    if len(tokens) < n:
        return []
    
    ngrams = []
    for i in range(len(tokens) - n + 1):
        ngrams.append(tuple(tokens[i:i+n]))
    
    return ngrams

# Extract bigrams and trigrams
print("Extracting k-grams...")
df['bigrams'] = df['tokens'].apply(lambda x: extract_ngrams(x, 2))
df['trigrams'] = df['tokens'].apply(lambda x: extract_ngrams(x, 3))

print("K-gram extraction complete!")
print(f"\nExample bigrams: {df['bigrams'].iloc[0][:5]}")
print(f"Example trigrams: {df['trigrams'].iloc[0][:5]}")

## 4. Markov Chain Model

We'll build separate Markov chains for malicious and benign prompts. Each chain stores transition counts, from which it derives:
- Smoothed transition probabilities: P(word_i | word_{i-1})
- Sequence log-probabilities for classification
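
Written out, the estimates computed by the class below take the following form. The symbols are notation introduced here for illustration: $c(\cdot)$ is a training count, $\alpha$ is the `smoothing` parameter, and $V$ is the chain's vocabulary (which, in the code, includes the START and END markers).

$$P(w_i \mid w_{i-1}) = \frac{c(w_{i-1}, w_i) + \alpha}{c(w_{i-1}) + \alpha\,|V|}$$

$$\log P(w_1, \dots, w_n) = \sum_{i=1}^{n+1} \log P(w_i \mid w_{i-1}), \qquad w_0 = \text{START},\; w_{n+1} = \text{END}$$

With $\alpha > 0$, every transition receives a non-zero probability, so the log-probability of any prompt is finite even when it contains word pairs never seen in training.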

In [None]:
class MarkovChain:
    """
    A simple Markov Chain model for text analysis.
    Stores transition probabilities between words.
    """
    
    def __init__(self, smoothing=1.0):
        """
        Initialize the Markov Chain.
        
        Args:
            smoothing: Laplace smoothing parameter (default=1.0)
        """
        self.transitions = defaultdict(lambda: defaultdict(int))
        self.word_counts = defaultdict(int)
        self.smoothing = smoothing
        self.vocabulary = set()
        
    def train(self, token_lists):
        """
        Train the Markov Chain on a list of token sequences.
        
        Args:
            token_lists: List of token lists
        """
        for tokens in token_lists:
            # Add START and END tokens
            tokens = ['<START>'] + tokens + ['<END>']
            
            # Build transitions
            for i in range(len(tokens) - 1):
                current_word = tokens[i]
                next_word = tokens[i + 1]
                
                self.transitions[current_word][next_word] += 1
                self.word_counts[current_word] += 1
                self.vocabulary.add(current_word)
                self.vocabulary.add(next_word)
    
    def get_probability(self, current_word, next_word):
        """
        Get the transition probability P(next_word | current_word).
        Uses Laplace smoothing for unseen transitions.
        
        Args:
            current_word: Current word
            next_word: Next word
        
        Returns:
            Transition probability
        """
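        vocab_size = len(self.vocabulary)

        # Read counts with .get() so lookups never insert new keys into the defaultdicts
        transition_count = self.transitions.get(current_word, {}).get(next_word, 0)
        context_count = self.word_counts.get(current_word, 0)

        # Laplace smoothing
        numerator = transition_count + self.smoothing
        denominator = context_count + (self.smoothing * vocab_size)

        if denominator == 0:
            # Untrained chain, or smoothing=0 with an unseen word: no evidence either way
            return 1.0 / vocab_size if vocab_size > 0 else 0.0

        return numerator / denominator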
    
    def get_sequence_log_probability(self, tokens):
        """
        Calculate the log probability of a token sequence.
        
        Args:
            tokens: List of tokens
        
        Returns:
            Log probability of the sequence
        """
        if len(tokens) == 0:
            return 0.0
        
        tokens = ['<START>'] + tokens + ['<END>']
        log_prob = 0.0
        
        for i in range(len(tokens) - 1):
            prob = self.get_probability(tokens[i], tokens[i + 1])
            log_prob += np.log(prob + 1e-10)  # Add small value to avoid log(0)
        
        return log_prob
    
    def get_most_likely_transitions(self, top_n=20):
        """
        Get the most frequent transitions.
        
        Args:
            top_n: Number of top transitions to return
        
        Returns:
            List of (current_word, next_word, count) tuples
        """
        all_transitions = []
        
        for current_word, next_words in self.transitions.items():
            if current_word in ['<START>', '<END>']:
                continue
            for next_word, count in next_words.items():
                if next_word in ['<START>', '<END>']:
                    continue
                all_transitions.append((current_word, next_word, count))
        
        # Sort by count
        all_transitions.sort(key=lambda x: x[2], reverse=True)
        
        return all_transitions[:top_n]

print("Markov Chain class defined!")

## 5. Train/Test Split

In [None]:
# Split data
X = df['tokens'].values
y = df['isMalicious'].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set size: {len(X_train)}")
print(f"Test set size: {len(X_test)}")
print(f"\nTraining set distribution:")
print(f"  Malicious: {y_train.sum()} ({y_train.sum()/len(y_train)*100:.1f}%)")
print(f"  Benign: {(1-y_train).sum()} ({(1-y_train).sum()/len(y_train)*100:.1f}%)")

## 6. Build Separate Markov Chains

In [None]:
# Separate training data by class
malicious_tokens = [X_train[i] for i in range(len(X_train)) if y_train[i] == 1]
benign_tokens = [X_train[i] for i in range(len(X_train)) if y_train[i] == 0]

print("Building Markov Chains...")

# Build malicious Markov Chain
malicious_mc = MarkovChain(smoothing=1.0)
malicious_mc.train(malicious_tokens)
print(f"✓ Malicious chain built: {len(malicious_mc.vocabulary)} unique words")

# Build benign Markov Chain
benign_mc = MarkovChain(smoothing=1.0)
benign_mc.train(benign_tokens)
print(f"✓ Benign chain built: {len(benign_mc.vocabulary)} unique words")

print("\nMarkov Chains built successfully!")

## 7. Classification Using Likelihood Ratio
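
A prompt's probability is a product of many small transition probabilities, so the comparison between classes is best done in log space. The sketch below uses made-up log-likelihood values and a hypothetical helper `posterior_malicious` to show how two log scores can be turned into a normalized posterior without exponentiating the raw values; it assumes a prior strictly between 0 and 1.

```python
import math

def posterior_malicious(log_p_mal, log_p_ben, prior_malicious=0.5):
    """Return P(malicious | prompt) from per-class log-likelihoods."""
    a = log_p_mal + math.log(prior_malicious)
    b = log_p_ben + math.log(1.0 - prior_malicious)
    m = max(a, b)  # subtract the max before exponentiating (log-sum-exp trick)
    exp_a, exp_b = math.exp(a - m), math.exp(b - m)
    return exp_a / (exp_a + exp_b)

# Illustrative values only: exponentiating either raw score underflows to 0.0
naive = math.exp(-900.0)
stable = posterior_malicious(-900.0, -905.0)
print(naive, stable)
```

Only the difference between the two log scores matters for the decision, which is why subtracting the larger one leaves the posterior unchanged while keeping the exponentials representable.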

In [None]:
def classify_prompt(tokens, malicious_mc, benign_mc, prior_malicious=0.5):
    """
    Classify a prompt as malicious or benign using likelihood ratio.
    
    Args:
        tokens: List of tokens
        malicious_mc: Malicious Markov Chain
        benign_mc: Benign Markov Chain
        prior_malicious: Prior probability of malicious class
    
    Returns:
        (prediction, malicious_score, benign_score)
    """
    if len(tokens) == 0:
        return 0, 0.0, 0.0
    
    # Calculate log probabilities
    log_prob_malicious = malicious_mc.get_sequence_log_probability(tokens)
    log_prob_benign = benign_mc.get_sequence_log_probability(tokens)
    
    # Add prior probabilities (in log space)
    log_prob_malicious += np.log(prior_malicious + 1e-10)
    log_prob_benign += np.log(1 - prior_malicious + 1e-10)
    
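    # Normalize scores for interpretability. Exponentiating the raw log
    # probabilities underflows to 0.0 for long prompts, so subtract the max first.
    max_log = max(log_prob_malicious, log_prob_benign)
    exp_malicious = np.exp(log_prob_malicious - max_log)
    exp_benign = np.exp(log_prob_benign - max_log)
    malicious_score = exp_malicious / (exp_malicious + exp_benign)
    benign_score = exp_benign / (exp_malicious + exp_benign)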
    
    # Classify based on higher probability
    prediction = 1 if log_prob_malicious > log_prob_benign else 0
    
    return prediction, malicious_score, benign_score

print("Classification function defined!")

## 8. Evaluate on Test Set

In [None]:
print("Evaluating on test set...")

# Calculate prior probability from training set
prior_malicious = y_train.sum() / len(y_train)
print(f"Prior probability of malicious: {prior_malicious:.3f}")

# Make predictions
predictions = []
malicious_scores = []
benign_scores = []

for tokens in X_test:
    pred, mal_score, ben_score = classify_prompt(tokens, malicious_mc, benign_mc, prior_malicious)
    predictions.append(pred)
    malicious_scores.append(mal_score)
    benign_scores.append(ben_score)

predictions = np.array(predictions)
print("\nEvaluation complete!")

In [None]:
# Calculate metrics
accuracy = accuracy_score(y_test, predictions)
precision = precision_score(y_test, predictions, zero_division=0)
recall = recall_score(y_test, predictions, zero_division=0)
f1 = f1_score(y_test, predictions, zero_division=0)

print("="*50)
print("MARKOV CHAIN PERFORMANCE METRICS")
print("="*50)
print(f"Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
print(f"Recall:    {recall:.4f} ({recall*100:.2f}%)")
print(f"F1-Score:  {f1:.4f} ({f1*100:.2f}%)")
print("="*50)

In [None]:
# Detailed classification report
print("\nDetailed Classification Report:")
print(classification_report(y_test, predictions, target_names=['Benign', 'Malicious']))

In [None]:
# Confusion Matrix
cm = confusion_matrix(y_test, predictions)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Benign', 'Malicious'],
            yticklabels=['Benign', 'Malicious'])
plt.title('Confusion Matrix - Markov Chain Classifier', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

print(f"\nTrue Negatives (Benign→Benign): {cm[0,0]}")
print(f"False Positives (Benign→Malicious): {cm[0,1]}")
print(f"False Negatives (Malicious→Benign): {cm[1,0]}")
print(f"True Positives (Malicious→Malicious): {cm[1,1]}")

## 9. Identify Most Discriminative Word Sequences

In [None]:
def get_discriminative_transitions(malicious_mc, benign_mc, top_n=20):
    """
    Find word transitions that are most indicative of malicious prompts.
    Uses likelihood ratio: P(transition|malicious) / P(transition|benign)
    """
    discriminative_scores = []
    
    # Get all transitions from malicious chain
    for current_word in malicious_mc.transitions:
        if current_word in ['<START>', '<END>']:
            continue
        
        for next_word in malicious_mc.transitions[current_word]:
            if next_word in ['<START>', '<END>']:
                continue
            
            # Get probabilities from both chains
            prob_malicious = malicious_mc.get_probability(current_word, next_word)
            prob_benign = benign_mc.get_probability(current_word, next_word)
            
            # Calculate likelihood ratio
            ratio = prob_malicious / (prob_benign + 1e-10)
            
            # Count occurrences
            count_malicious = malicious_mc.transitions[current_word][next_word]
            
            # Only consider transitions that appear at least 3 times
            if count_malicious >= 3:
                discriminative_scores.append({
                    'transition': f"{current_word} → {next_word}",
                    'ratio': ratio,
                    'prob_malicious': prob_malicious,
                    'prob_benign': prob_benign,
                    'count': count_malicious
                })
    
    # Sort by likelihood ratio
    discriminative_scores.sort(key=lambda x: x['ratio'], reverse=True)
    
    return discriminative_scores[:top_n]

print("Finding most discriminative word sequences...")
discriminative_transitions = get_discriminative_transitions(malicious_mc, benign_mc, top_n=25)

print("\n" + "="*80)
print("TOP 25 WORD SEQUENCES INDICATING MALICIOUS PROMPTS")
print("="*80)
print(f"{'Rank':<6}{'Word Sequence':<30}{'Likelihood Ratio':<20}{'Occurrences':<15}")
print("-"*80)

for i, item in enumerate(discriminative_transitions, 1):
    print(f"{i:<6}{item['transition']:<30}{item['ratio']:<20.2f}{item['count']:<15}")

print("="*80)

In [None]:
# Visualize top discriminative transitions
top_15 = discriminative_transitions[:15]

transitions = [item['transition'] for item in top_15]
ratios = [item['ratio'] for item in top_15]

plt.figure(figsize=(12, 8))
plt.barh(range(len(transitions)), ratios, color='crimson', alpha=0.7)
plt.yticks(range(len(transitions)), transitions)
plt.xlabel('Likelihood Ratio (Malicious / Benign)', fontsize=12)
plt.title('Top 15 Word Sequences Indicating Malicious Prompts', fontsize=14, fontweight='bold')
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

## 10. Analyze Trigrams for Longer Sequences

In [None]:
# Extract trigrams from both classes
malicious_trigrams = []
benign_trigrams = []

for i in range(len(X_train)):
    trigrams = extract_ngrams(X_train[i], 3)
    if y_train[i] == 1:
        malicious_trigrams.extend(trigrams)
    else:
        benign_trigrams.extend(trigrams)

# Count trigrams
malicious_trigram_counts = Counter(malicious_trigrams)
benign_trigram_counts = Counter(benign_trigrams)

print(f"Total malicious trigrams: {len(malicious_trigrams)}")
print(f"Unique malicious trigrams: {len(malicious_trigram_counts)}")
print(f"\nTotal benign trigrams: {len(benign_trigrams)}")
print(f"Unique benign trigrams: {len(benign_trigram_counts)}")

In [None]:
# Find trigrams unique to or highly indicative of malicious prompts
def get_discriminative_trigrams(malicious_counts, benign_counts, top_n=20, min_count=3):
    """
    Find trigrams that are most indicative of malicious prompts.
    """
    # Shared trigram vocabulary so both classes are smoothed over the same event space
    vocab_size = len(set(malicious_counts) | set(benign_counts))
    total_malicious = sum(malicious_counts.values())
    total_benign = sum(benign_counts.values())
    
    trigram_scores = []
    
    for trigram, mal_count in malicious_counts.items():
        if mal_count < min_count:
            continue
        
        ben_count = benign_counts.get(trigram, 0)
        
        # Calculate add-one smoothed probabilities for both classes
        prob_malicious = (mal_count + 1) / (total_malicious + vocab_size)
        prob_benign = (ben_count + 1) / (total_benign + vocab_size)
        
        # Likelihood ratio
        ratio = prob_malicious / prob_benign
        
        trigram_scores.append({
            'trigram': ' '.join(trigram),
            'ratio': ratio,
            'malicious_count': mal_count,
            'benign_count': ben_count
        })
    
    # Sort by ratio
    trigram_scores.sort(key=lambda x: x['ratio'], reverse=True)
    
    return trigram_scores[:top_n]

discriminative_trigrams = get_discriminative_trigrams(malicious_trigram_counts, benign_trigram_counts, top_n=25)

print("\n" + "="*90)
print("TOP 25 TRIGRAMS (3-WORD SEQUENCES) INDICATING MALICIOUS PROMPTS")
print("="*90)
print(f"{'Rank':<6}{'Trigram':<40}{'Ratio':<15}{'Mal Count':<12}{'Ben Count'}")
print("-"*90)

for i, item in enumerate(discriminative_trigrams, 1):
    print(f"{i:<6}{item['trigram']:<40}{item['ratio']:<15.2f}{item['malicious_count']:<12}{item['benign_count']}")

print("="*90)

## 11. Example Predictions with Explanations

In [None]:
def explain_prediction(tokens, true_label, malicious_mc, benign_mc):
    """
    Explain why a prompt was classified as malicious or benign.
    """
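    # Use the training-set prior from the evaluation cell so the label shown here
    # matches the entry in `predictions` (the default prior of 0.5 can disagree)
    pred, mal_score, ben_score = classify_prompt(tokens, malicious_mc, benign_mc, prior_malicious)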
    
    # Get bigrams
    bigrams = extract_ngrams(tokens, 2)
    
    # Calculate transition probabilities for each bigram
    bigram_ratios = []
    for bg in bigrams:
        if len(bg) == 2:
            prob_mal = malicious_mc.get_probability(bg[0], bg[1])
            prob_ben = benign_mc.get_probability(bg[0], bg[1])
            ratio = prob_mal / (prob_ben + 1e-10)
            bigram_ratios.append((f"{bg[0]} → {bg[1]}", ratio))
    
    # Sort by ratio
    bigram_ratios.sort(key=lambda x: x[1], reverse=True)
    
    print("\n" + "="*80)
    print("PREDICTION EXPLANATION")
    print("="*80)
    print(f"Text: {' '.join(tokens[:50])}..." if len(tokens) > 50 else f"Text: {' '.join(tokens)}")
    print(f"\nTrue Label: {'Malicious' if true_label == 1 else 'Benign'}")
    print(f"Predicted: {'Malicious' if pred == 1 else 'Benign'}")
    print(f"Correct: {'✓ Yes' if pred == true_label else '✗ No'}")
    print(f"\nMalicious Score: {mal_score:.2e}")
    print(f"Benign Score: {ben_score:.2e}")
    
    if len(bigram_ratios) > 0:
        print("\nTop word sequences contributing to maliciousness:")
        for i, (bg, ratio) in enumerate(bigram_ratios[:5], 1):
            print(f"  {i}. '{bg}' (ratio: {ratio:.2f})")
    print("="*80)

# Show examples
print("\n" + "#"*80)
print("EXAMPLE PREDICTIONS")
print("#"*80)

# Find some interesting examples
# 1. Correctly classified malicious
correct_malicious_idx = None
for i in range(len(X_test)):
    if y_test[i] == 1 and predictions[i] == 1:
        correct_malicious_idx = i
        break

if correct_malicious_idx is not None:
    print("\n[1] CORRECTLY CLASSIFIED MALICIOUS PROMPT:")
    explain_prediction(X_test[correct_malicious_idx], y_test[correct_malicious_idx], malicious_mc, benign_mc)

# 2. Correctly classified benign
correct_benign_idx = None
for i in range(len(X_test)):
    if y_test[i] == 0 and predictions[i] == 0:
        correct_benign_idx = i
        break

if correct_benign_idx is not None:
    print("\n[2] CORRECTLY CLASSIFIED BENIGN PROMPT:")
    explain_prediction(X_test[correct_benign_idx], y_test[correct_benign_idx], malicious_mc, benign_mc)

# 3. Misclassified example (if any)
misclassified_idx = None
for i in range(len(X_test)):
    if y_test[i] != predictions[i]:
        misclassified_idx = i
        break

if misclassified_idx is not None:
    print("\n[3] MISCLASSIFIED EXAMPLE:")
    explain_prediction(X_test[misclassified_idx], y_test[misclassified_idx], malicious_mc, benign_mc)

## 12. Summary and Key Findings

In [None]:
print("\n" + "#"*80)
print("SUMMARY OF KEY FINDINGS")
print("#"*80)

print("\n1. MODEL PERFORMANCE:")
print(f"   - Accuracy: {accuracy*100:.2f}%")
print(f"   - Precision: {precision*100:.2f}%")
print(f"   - Recall: {recall*100:.2f}%")
print(f"   - F1-Score: {f1*100:.2f}%")

print("\n2. TOP MALICIOUS WORD SEQUENCES (Bigrams):")
for i, item in enumerate(discriminative_transitions[:5], 1):
    print(f"   {i}. '{item['transition']}' (ratio: {item['ratio']:.2f})")

print("\n3. TOP MALICIOUS TRIGRAMS:")
for i, item in enumerate(discriminative_trigrams[:5], 1):
    print(f"   {i}. '{item['trigram']}' (ratio: {item['ratio']:.2f})")

print("\n4. KEY INSIGHTS:")
print("   - First-order Markov chains capture adjacent-word patterns in prompts")
print("   - The highest-ratio sequences listed above are much more likely under the malicious chain")
print("   - Check the ranked lists for instruction-override terms such as 'forget', 'ignore', 'bypass'")
print("   - The model flags prompt injection and jailbreak phrasing that resembles the training data")

print("\n5. COMPUTATIONAL EFFICIENCY:")
print("   - Training: a single counting pass over the training tokens")
print("   - Inference: one dictionary lookup per adjacent word pair in the prompt")
print("   - Memory: two nested count dictionaries, one per class")

print("\n" + "#"*80)

## Conclusion

This notebook demonstrates how Markov Chains can be used to:
1. **Identify malicious patterns**: Certain word sequences are strong indicators of malicious prompts
2. **Classify prompts**: Using likelihood ratios between malicious and benign chains
3. **Explain predictions**: Show which word sequences contribute to classification decisions

### Advantages:
- Simple and interpretable
- Fast training and inference
- Provides explainable results (shows which sequences raise the malicious score)
- Low computational requirements (no GPU needed; runs on Colab)

### Limitations:
- Only captures dependencies between adjacent words (no long-range context)
- Assumes the first-order Markov property (the next word depends only on the current word)
- May struggle with novel attack patterns not seen in training

### Future Improvements:
- Higher-order Markov chains (trigrams, 4-grams)
- Combine with TF-IDF or embeddings
- Ensemble with other classifiers
- Use variable-length n-grams
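
As a rough illustration of the first improvement, a second-order chain conditions each word on the two words before it. The class below is a sketch under stated assumptions, not part of this notebook's pipeline: the name `SecondOrderChain`, the double START padding, and the add-alpha smoothing over the chain's own vocabulary are all choices made for this example.

```python
from collections import defaultdict
import math

class SecondOrderChain:
    """Toy trigram model: P(next | previous two words) with add-alpha smoothing."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha
        self.counts = defaultdict(lambda: defaultdict(int))
        self.vocab = set()

    @staticmethod
    def _pad(tokens):
        # Two START markers give the first real word a full two-word context
        return ["<START>", "<START>"] + list(tokens) + ["<END>"]

    def train(self, token_lists):
        for tokens in token_lists:
            padded = self._pad(tokens)
            self.vocab.update(padded)
            for a, b, c in zip(padded, padded[1:], padded[2:]):
                self.counts[(a, b)][c] += 1

    def log_prob(self, tokens):
        padded = self._pad(tokens)
        total = 0.0
        for a, b, c in zip(padded, padded[1:], padded[2:]):
            # .get() avoids inserting unseen contexts into the defaultdict
            context = self.counts.get((a, b), {})
            seen = context.get(c, 0)
            context_total = sum(context.values())
            total += math.log(
                (seen + self.alpha) / (context_total + self.alpha * len(self.vocab))
            )
        return total

chain = SecondOrderChain()
chain.train([["ignore", "all", "previous", "instructions"]])
seen_score = chain.log_prob(["ignore", "all", "previous", "instructions"])
unseen_score = chain.log_prob(["instructions", "previous", "all", "ignore"])
print(seen_score > unseen_score)
```

The longer context makes the model more specific but also sparser: most two-word contexts in a test prompt will be unseen, so smoothing or backoff to the first-order chain matters more than it does for bigrams.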