# In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import ParameterGrid
from collections import defaultdict
import matplotlib.pyplot as plt
import joblib
import warnings
warnings.filterwarnings('ignore')

class EnhancedNASInvalidMessageDetector:
    def __init__(self):
        """Create the untrained detectors, scoring weights and 5GMM lookup table."""
        # One default outlier fraction feeds both detectors and the
        # bookkeeping attribute below.
        default_outlier_fraction = 0.01

        # --- ML components ---
        self.scaler = StandardScaler()
        isolation_forest = IsolationForest(contamination=default_outlier_fraction, random_state=42)
        one_class_svm = OneClassSVM(nu=default_outlier_fraction, kernel='rbf', gamma='auto')
        self.models = {
            'isolation_forest': isolation_forest,
            'one_class_svm': one_class_svm,
        }
        self.trained_models = {}
        self.feature_columns = None
        self.contamination_rate = default_outlier_fraction
        # The ensemble leans mostly on the Isolation Forest.
        self.ensemble_weights = dict(isolation_forest=0.8, one_class_svm=0.2)
        self.optimal_threshold = None

        # --- Score blending (rules are weighted above the ML output) ---
        self.confidence_threshold = 0.8
        self.rule_based_weight = 0.7
        self.ml_based_weight = 0.3

        # --- Valid 5GMM message types: decimal code (as text) -> name ---
        message_table = (
            ('65', 'RegistrationRequest'),
            ('66', 'RegistrationAccept'),
            ('67', 'RegistrationComplete'),
            ('68', 'RegistrationReject'),
            ('69', 'DeregistrationRequestUE'),
            ('70', 'DeregistrationAcceptUE'),
            ('71', 'DeregistrationRequestAMF'),
            ('72', 'DeregistrationAcceptAMF'),
            ('76', 'ServiceRequest'),
            ('77', 'ServiceReject'),
            ('78', 'ServiceAccept'),
            ('79', 'ControlPlaneServiceRequest'),
            ('80', 'NetworkSliceSpecificAuthenticationCommand'),
            ('81', 'NetworkSliceSpecificAuthenticationComplete'),
            ('82', 'NetworkSliceSpecificAuthenticationResult'),
            ('84', 'ConfigurationUpdateCommand'),
            ('85', 'ConfigurationUpdateComplete'),
            ('86', 'AuthenticationRequest'),
            ('87', 'AuthenticationResponse'),
            ('88', 'AuthenticationReject'),
            ('89', 'AuthenticationFailure'),
            ('90', 'AuthenticationResult'),
            ('91', 'IdentityRequest'),
            ('92', 'IdentityResponse'),
            ('93', 'SecurityModeCommand'),
            ('94', 'SecurityModeComplete'),
            ('95', 'SecurityModeReject'),
            ('100', 'Status'),
            ('101', 'Notification'),
            ('102', 'NotificationResponse'),
            ('103', 'ULNASTransport'),
            ('104', 'DLNASTransport'),
        )
        self.valid_5gmm_types = dict(message_table)

        # --- Per-session state; unseen sessions start idle and unsecured ---
        def new_session_state():
            return {'state': 'IDLE', 'security_activated': False}

        self.session_states = defaultdict(new_session_state)
        
    def preprocess_data(self, df):
        """Convert string data to proper types"""
        df = df.copy()
        numeric_cols = ['Time', 'AMF_UE_NGAP_ID', 'procedureCode', 'Type', 'Seqn', 'SecHdr', 'EPD', 'spare']
        for col in numeric_cols:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')
        return df
    
    def validate_nas_message(self, row):
        """Validate individual NAS message against protocol rules"""
        violations = []
        violation_severity = []  # Track severity of violations
        
        # Check message type validity
        msg_type = str(int(row['Type'])) if not pd.isna(row['Type']) else None
        if msg_type and msg_type not in self.valid_5gmm_types:
            violations.append(f"Invalid message type: {msg_type}")
            violation_severity.append(3)  # High severity
        
        # Check EPD (should be 126 for 5GMM)
        if not pd.isna(row['EPD']) and int(row['EPD']) != 126:
            violations.append(f"Invalid EPD: {int(row['EPD'])}")
            violation_severity.append(2)  # Medium severity
        
        # Check security header constraints
        sec_hdr = row['SecHdr'] if not pd.isna(row['SecHdr']) else None
        if sec_hdr is not None:
            if sec_hdr < 0 or sec_hdr > 5:
                violations.append(f"Invalid security header: {int(sec_hdr)}")
                violation_severity.append(3)  # High severity
        
        # Check sequence number validity
        seq_num = row['Seqn'] if not pd.isna(row['Seqn']) else None
        if seq_num is not None and (seq_num < 0 or seq_num > 255):
            violations.append(f"Invalid sequence number: {seq_num}")
            violation_severity.append(1)  # Low severity
        
        # Check spare field (should be 0)
        spare = row['spare'] if not pd.isna(row['spare']) else None
        if spare is not None and int(spare) != 0:
            violations.append(f"Invalid spare field: {int(spare)}")
            violation_severity.append(1)  # Low severity
        
        # Calculate total severity score
        total_severity = sum(violation_severity) if violation_severity else 0
        
        return violations, total_severity
    
    def engineer_features_enhanced(self, df):
        """Enhanced feature engineering for invalid message detection"""
        # Filter NAS messages
        nas_df = df.copy()
        nas_df = nas_df.sort_values('Time').reset_index(drop=True)
        
        features = pd.DataFrame()
        
        # Basic features
        features['time'] = nas_df['Time']
        features['session_id'] = nas_df['AMF_UE_NGAP_ID']
        features['msg_type'] = nas_df['Type'].fillna(-1)
        features['proc_code'] = nas_df['procedureCode'].fillna(-1)
        features['sequence'] = nas_df['Seqn'].fillna(-1)
        features['sec_hdr'] = nas_df['SecHdr'].fillna(-1)
        features['epd'] = nas_df['EPD'].fillna(-1)
        features['spare'] = nas_df['spare'].fillna(-1)
        
        # Invalid message type indicator
        features['is_valid_msg_type'] = features['msg_type'].astype(str).isin(self.valid_5gmm_types.keys()).astype(int)
        features['is_valid_epd'] = (features['epd'] == 126).astype(int)
        features['is_valid_sec_hdr'] = ((features['sec_hdr'] >= 0) & (features['sec_hdr'] <= 5)).astype(int)
        features['is_valid_seq'] = ((features['sequence'] >= 0) & (features['sequence'] <= 255)).astype(int)
        features['is_valid_spare'] = (features['spare'] == 0).astype(int)
        
        # Add severity scores for each message
        severity_scores = []
        for idx, row in nas_df.iterrows():
            _, severity = self.validate_nas_message(row)
            severity_scores.append(severity)
        features['severity_score'] = severity_scores
        
        # Protocol state features
        for session_id in features['session_id'].unique():
            mask = features['session_id'] == session_id
            session_data = features[mask]
            
            # Time-based features
            features.loc[mask, 'time_diff'] = session_data['time'].diff().fillna(0)
            features.loc[mask, 'time_since_start'] = session_data['time'] - session_data['time'].min()
            
            # Message sequence analysis
            features.loc[mask, 'msg_position'] = range(1, len(session_data) + 1)
            features.loc[mask, 'msg_rate'] = features.loc[mask, 'msg_position'] / (features.loc[mask, 'time_since_start'] + 0.001)
            
            # Protocol flow validation
            prev_msg_type = -1
            for idx in session_data.index:
                curr_msg_type = features.loc[idx, 'msg_type']
                features.loc[idx, 'valid_sequence'] = self.is_valid_message_sequence(prev_msg_type, curr_msg_type)
                prev_msg_type = curr_msg_type
            
            # Security context validation
            security_active = False
            for idx in session_data.index:
                msg_type = features.loc[idx, 'msg_type']
                sec_hdr = features.loc[idx, 'sec_hdr']
                
                if msg_type == 94:  # SecurityModeComplete
                    security_active = True
                
                if security_active and sec_hdr in [0, -1]:
                    features.loc[idx, 'security_violation'] = 1
                else:
                    features.loc[idx, 'security_violation'] = 0
        
        # Anomaly score based on validity checks
        features['validity_score'] = (
            features['is_valid_msg_type'] + 
            features['is_valid_epd'] + 
            features['is_valid_sec_hdr'] + 
            features['is_valid_seq'] +
            features['is_valid_spare']
        ) / 5.0
        
        # Weighted invalid risk score with emphasis on high-severity violations
        features['invalid_risk_score'] = (
            (1 - features['is_valid_msg_type']) * 5 +  # Increased weight
            (1 - features['is_valid_epd']) * 3 +      # Increased weight
            (1 - features['is_valid_sec_hdr']) * 4 +  # Increased weight
            (1 - features['is_valid_seq']) * 0.5 +    # Reduced weight
            (1 - features['is_valid_spare']) * 0.5 +  # Reduced weight
            features['security_violation'] * 2 +
            features['severity_score'] * 0.5  # Add severity score
        )
        
        # Statistical features
        features['msg_type_frequency'] = features.groupby(['session_id', 'msg_type'])['msg_type'].transform('count')
        features['unusual_msg_type'] = (features['msg_type_frequency'] == 1).astype(int)
        
        # Add confidence features
        features['confidence_score'] = 1.0 - (features['invalid_risk_score'] / features['invalid_risk_score'].max())
        
        # Normalize features to reduce SVM sensitivity
        numeric_features = ['time_diff', 'time_since_start', 'msg_rate', 'msg_position']
        for feat in numeric_features:
            if feat in features.columns:
                mean_val = features[feat].mean()
                std_val = features[feat].std()
                if std_val > 0:
                    features[f'{feat}_normalized'] = (features[feat] - mean_val) / std_val
                else:
                    features[f'{feat}_normalized'] = 0
        
        features = features.fillna(0)
        
        return features, nas_df
    
    def is_valid_message_sequence(self, prev_type, curr_type):
        """Check if message sequence is valid according to 5GMM protocol"""
        prev_str = str(int(prev_type)) if prev_type != -1 else None
        curr_str = str(int(curr_type)) if curr_type != -1 else None
        
        # Define valid message sequences (simplified)
        valid_sequences = {
            '65': ['66', '68', '86', '91'],  # After RegistrationRequest
            '86': ['87', '89'],  # After AuthenticationRequest
            '87': ['90', '88'],  # After AuthenticationResponse
            '93': ['94', '95'],  # After SecurityModeCommand
            '94': ['66'],  # After SecurityModeComplete
        }
        
        if prev_str is None or prev_str not in valid_sequences:
            return 1  # Consider valid if no specific rule
        
        return 1 if curr_str in valid_sequences.get(prev_str, []) else 0
    
    def detect_invalid_messages(self, df):
        """Detect invalid NAS messages based on protocol rules"""
        df = self.preprocess_data(df)
        df = df.sort_values('Time').reset_index(drop=True)
        
        invalid_messages = []
        
        for idx, row in df.iterrows():
            violations, severity = self.validate_nas_message(row)
            
            # Include ALL violations for ground truth (no severity filtering)
            if violations:
                invalid_messages.append({
                    'index': idx,
                    'time': row['Time'],
                    'session_id': row['AMF_UE_NGAP_ID'],
                    'msg_type': row['Type'],
                    'proc_code': row['procedureCode'],
                    'violations': violations,
                    'violation_count': len(violations),
                    'severity': severity
                })
        
        return invalid_messages
    
    def optimize_hyperparameters(self, X_train, features_df):
        """Grid-search both detectors with a strong bias against false positives.

        Rows of ``features_df`` whose ``severity_score`` is 0 are treated as
        normal traffic; a candidate's false-positive rate is the share of
        those rows it flags as outliers (-1).  The winning parameter sets
        replace the entries of ``self.models`` and the selected Isolation
        Forest contamination is stored in ``self.contamination_rate``.

        Args:
            X_train: Scaled training matrix (indexable by an integer array);
                row ``i`` must describe the same message as row ``i`` of
                ``features_df``.
            features_df: Feature table with a ``severity_score`` column.

        Returns:
            dict: Selected parameters keyed by ``'isolation_forest'`` and
            ``'one_class_svm'``.  An entry is ``None`` when no candidate of
            that kind could be evaluated; the matching model in
            ``self.models`` is then left unchanged.
        """
        print("\n" + "="*70)
        print("HYPERPARAMETER OPTIMIZATION")
        print("="*70)

        best_params = {}

        # Split the rows once: any rule violation marks a known anomaly.
        severity = features_df['severity_score'].to_numpy()
        anomaly_indices = np.where(severity > 0)[0]
        normal_indices = np.where(severity == 0)[0]

        def false_positive_rate(predictions):
            # Share of normal rows that a model flags as outliers.
            if len(normal_indices) == 0:
                return 0
            return (predictions[normal_indices] == -1).sum() / len(normal_indices)

        # Contamination observed in the training data (reported only; the
        # grid below uses fixed candidate values).
        if len(anomaly_indices) == 0:
            expected_contamination = 0.02
            print("No anomalies detected in training data, using moderate contamination rate")
        else:
            expected_contamination = len(anomaly_indices) / len(X_train)

        # Clamp to the range accepted for contamination, (0.0, 0.5].
        expected_contamination = max(0.01, min(0.5, expected_contamination))
        print(f"Expected contamination rate from training data: {expected_contamination:.2%}")

        # ---- Isolation Forest ------------------------------------------
        iso_param_grid = {
            'contamination': [0.01, 0.02, 0.03, 0.04],
            'n_estimators': [100, 150],
            'max_samples': ['auto', 0.9],
            'max_features': [0.8, 1.0],
            'bootstrap': [True]
        }

        print("\nOptimizing Isolation Forest...")
        best_score = float('-inf')
        best_iso_params = None

        for params in ParameterGrid(iso_param_grid):
            model = IsolationForest(random_state=42, **params)
            model.fit(X_train)

            # Reward a wide spread of anomaly scores, penalise flagged
            # normal rows.
            score_std = np.std(model.score_samples(X_train))
            fp_rate = false_positive_rate(model.predict(X_train))
            combined_score = score_std - (fp_rate * 10)

            if combined_score > best_score:
                best_score = combined_score
                best_iso_params = params

        print(f"Best Isolation Forest params: {best_iso_params}")
        best_params['isolation_forest'] = best_iso_params

        # ---- One-Class SVM ---------------------------------------------
        # Fit on a random 80% subsample of the normal rows to limit
        # overfitting.  The generator is seeded so repeated runs select the
        # same rows, in line with the fixed random_state used above.
        if len(normal_indices) > 100:
            train_size = int(0.8 * len(normal_indices))
            rng = np.random.RandomState(42)
            train_indices = rng.choice(normal_indices, train_size, replace=False)
            X_train_svm = X_train[train_indices]
        else:
            X_train_svm = X_train

        svm_param_grid = {
            'nu': [0.001, 0.002, 0.003, 0.005],  # very low outlier bound
            'kernel': ['rbf'],
            'gamma': ['auto', 'scale'],
            'tol': [0.001],
            'shrinking': [True, False]
        }

        print("\nOptimizing One-Class SVM for minimal false positives...")
        best_score = float('-inf')
        best_svm_params = None
        best_fp_rate = 1.0

        for params in ParameterGrid(svm_param_grid):
            try:
                model = OneClassSVM(**params)
                model.fit(X_train_svm)
                # Evaluate on the full training set, not just the subsample.
                predictions = model.predict(X_train)
            except Exception as e:
                # Best effort: one failing candidate must not abort the
                # search, but the failure is reported instead of hidden.
                print(f"  Skipping One-Class SVM params {params}: {e}")
                continue

            fp_rate = false_positive_rate(predictions)

            # Candidates under 2% false positives compete on a positive
            # score; everything above is pushed far below them.
            if fp_rate < 0.02:
                score = 1.0 - (fp_rate * 50)
            else:
                score = -fp_rate * 100

            if score > best_score:
                best_score = score
                best_svm_params = params
                best_fp_rate = fp_rate

        print(f"Best One-Class SVM params: {best_svm_params}")
        print(f"Best SVM FP rate: {best_fp_rate:.2%}")
        best_params['one_class_svm'] = best_svm_params

        # Install the winners; keep the current model when a search produced
        # no usable candidate rather than failing on ``**None``.
        if best_iso_params is not None:
            self.models['isolation_forest'] = IsolationForest(random_state=42, **best_iso_params)
            self.contamination_rate = best_iso_params['contamination']
        else:
            print("No Isolation Forest candidate could be scored; keeping current model")

        if best_svm_params is not None:
            self.models['one_class_svm'] = OneClassSVM(**best_svm_params)
        else:
            print("No One-Class SVM candidate could be fitted; keeping current model")

        return best_params
    
    def ensemble_predict_with_confidence(self, X_test, features_df):
        """Enhanced ensemble prediction with improved detection"""
        predictions = {}
        scores = {}
        
        for model_name, model in self.trained_models.items():
            predictions[model_name] = model.predict(X_test)
            
            if hasattr(model, 'decision_function'):
                scores[model_name] = model.decision_function(X_test)
            else:
                scores[model_name] = model.score_samples(X_test)
        
        # Make predictions based on smart logic
        ensemble_predictions = np.ones(len(X_test))
        
        # Decision logic that relies heavily on protocol violations
        for i in range(len(X_test)):
            iso_pred = predictions['isolation_forest'][i]
            svm_pred = predictions['one_class_svm'][i]
            severity = features_df['severity_score'].values[i]
            
            # CHANGED: Rule-based approach with ML support
            # 1. Always flag high severity violations (EPD, message type)
            if severity >= 2:
                ensemble_predictions[i] = -1
            # 2. For low severity, only flag if Isolation Forest agrees
            # (ignore SVM due to high FP rate)
            elif severity == 1 and iso_pred == -1:
                ensemble_predictions[i] = -1
            # 3. If no violation but BOTH models strongly agree, consider it
            elif severity == 0 and iso_pred == -1 and svm_pred == -1:
                # Double-check with score percentiles
                iso_score = scores['isolation_forest'][i]
                svm_score = scores['one_class_svm'][i]
                iso_threshold = np.percentile(scores['isolation_forest'], 1)  # Very strict
                svm_threshold = np.percentile(scores['one_class_svm'], 1)  # Very strict
                
                if iso_score < iso_threshold and svm_score < svm_threshold:
                    ensemble_predictions[i] = -1
        
        # Calculate combined score for reporting
        ml_score = np.zeros(len(X_test))
        for model_name in self.ensemble_weights:
            if model_name in scores:
                normalized_score = (scores[model_name] - scores[model_name].min()) / \
                                 (scores[model_name].max() - scores[model_name].min() + 1e-10)
                ml_score += normalized_score * self.ensemble_weights[model_name]
        
        rule_based_score = features_df['severity_score'].values / (features_df['severity_score'].max() + 0.001)
        combined_score = (self.ml_based_weight * ml_score + self.rule_based_weight * rule_based_score)
        
        return ensemble_predictions, combined_score
    
    def train(self, train_data, optimize_params=True):
        """Train ML models on clean training data"""
        print("="*70)
        print("TRAINING PHASE - INVALID MESSAGE DETECTION")
        print("="*70)
        
        train_df = self.preprocess_data(train_data)
        print(f"Training data shape: {train_df.shape}")
        
        features, nas_df = self.engineer_features_enhanced(train_df)
        
        print(f"Training on {len(nas_df)} NAS messages")
        print(f"Feature dimensions: {features.shape}")
        
        self.feature_columns = features.columns.tolist()
        X_train = self.scaler.fit_transform(features)
        
        # Optimize hyperparameters
        if optimize_params:
            best_params = self.optimize_hyperparameters(X_train, features)
        
        # Train models - CHANGED: Simplified training
        for model_name, model in self.models.items():
            print(f"\nTraining {model_name}...")
            model.fit(X_train)
            self.trained_models[model_name] = model
            
            # Test on training data to check anomaly rate
            predictions = model.predict(X_train)
            anomaly_rate = (predictions == -1).sum() / len(predictions)
            print(f"  Training anomaly rate: {anomaly_rate:.2%}")
        
        # CHANGED: More conservative ensemble weights
        self.ensemble_weights = {'isolation_forest': 0.6, 'one_class_svm': 0.4}
        self.rule_based_weight = 0.9  # CHANGED: Increased rule weight
        self.ml_based_weight = 0.1   # CHANGED: Decreased ML weight
        
        print("\nTraining complete!")
    
    def optimize_ensemble_weights(self, X_train, features_df):
        """Optimize ensemble weights based on model performance"""
        print("\nOptimizing ensemble weights...")
        
        scores = {}
        predictions = {}
        
        for model_name, model in self.trained_models.items():
            predictions[model_name] = model.predict(X_train)
            if hasattr(model, 'decision_function'):
                scores[model_name] = model.decision_function(X_train)
            else:
                scores[model_name] = model.score_samples(X_train)
        
        # Calculate false positive tendency for each model
        # based on normal samples (severity_score == 0)
        fp_rates = {}
        detection_rates = {}
        
        normal_mask = features_df['severity_score'] == 0
        anomaly_mask = features_df['severity_score'] > 0
        
        for model_name in predictions:
            # False positive rate
            if normal_mask.sum() > 0:
                fp_rate = (predictions[model_name][normal_mask] == -1).sum() / normal_mask.sum()
            else:
                fp_rate = 0
            fp_rates[model_name] = fp_rate
            
            # Detection rate
            detection_rates[model_name] = (predictions[model_name] == -1).sum() / len(predictions[model_name])
        
        print(f"False positive rates: {fp_rates}")
        print(f"Detection rates: {detection_rates}")
        
        # Weight heavily based on false positive rate
        # Lower FP rate = higher weight
        weights = {}
        for name in fp_rates:
            if fp_rates[name] < 0.01:
                weights[name] = 1.0  # Excellent
            elif fp_rates[name] < 0.1:
                weights[name] = 0.5  # Good
            elif fp_rates[name] < 0.5:
                weights[name] = 0.1  # Poor
            else:
                weights[name] = 0.01  # Very poor
        
        # Normalize weights
        total_weight = sum(weights.values())
        if total_weight > 0:
            self.ensemble_weights = {name: w/total_weight for name, w in weights.items()}
        else:
            # Fallback to default weights
            self.ensemble_weights = {'isolation_forest': 0.8, 'one_class_svm': 0.2}
        
        print(f"Optimized ensemble weights: {self.ensemble_weights}")
    
    def test(self, test_data, use_ensemble=True):
        """Test ML models and compare with ground truth"""
        print("\n" + "="*70)
        print("TESTING PHASE - INVALID MESSAGE DETECTION")
        print("="*70)
        
        # Get ground truth invalid messages
        invalid_messages = self.detect_invalid_messages(test_data)
        print(f"\nGround truth: {len(invalid_messages)} invalid messages found")
        
        # Prepare test data
        test_df = self.preprocess_data(test_data)
        features, nas_df = self.engineer_features_enhanced(test_df)
        features_subset = features[self.feature_columns]
        X_test = self.scaler.transform(features_subset)
        
        # Get true labels
        y_true = np.zeros(len(nas_df))
        for invalid_msg in invalid_messages:
            y_true[invalid_msg['index']] = 1
        
        results = {}
        
        # Test individual models
        for model_name, model in self.trained_models.items():
            predictions = model.predict(X_test)
            scores = model.decision_function(X_test) if hasattr(model, 'decision_function') else model.score_samples(X_test)
            
            # Convert predictions
            y_pred = (predictions == -1).astype(int)
            
            # Calculate metrics
            tp = np.sum((y_true == 1) & (y_pred == 1))
            fp = np.sum((y_true == 0) & (y_pred == 1))
            fn = np.sum((y_true == 1) & (y_pred == 0))
            tn = np.sum((y_true == 0) & (y_pred == 0))
            
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            
            results[model_name] = {
                'predictions': y_pred,
                'scores': scores,
                'metrics': {
                    'true_positives': tp,
                    'false_positives': fp,
                    'false_negatives': fn,
                    'true_negatives': tn,
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1
                },
                'nas_df': nas_df
            }
        
        # Add enhanced ensemble results
        if use_ensemble:
            ensemble_predictions, ensemble_scores = self.ensemble_predict_with_confidence(X_test, features)
            y_pred_ensemble = (ensemble_predictions == -1).astype(int)
            
            tp = np.sum((y_true == 1) & (y_pred_ensemble == 1))
            fp = np.sum((y_true == 0) & (y_pred_ensemble == 1))
            fn = np.sum((y_true == 1) & (y_pred_ensemble == 0))
            tn = np.sum((y_true == 0) & (y_pred_ensemble == 0))
            
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            
            results['ensemble'] = {
                'predictions': y_pred_ensemble,
                'scores': ensemble_scores,
                'metrics': {
                    'true_positives': tp,
                    'false_positives': fp,
                    'false_negatives': fn,
                    'true_negatives': tn,
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1
                },
                'nas_df': nas_df
            }
        
        return results, invalid_messages
    
    def print_metrics(self, results):
        """Print detailed metrics for each model"""
        print("\n" + "="*70)
        print("MODEL PERFORMANCE METRICS")
        print("="*70)
        
        for model_name, result in results.items():
            metrics = result['metrics']
            print(f"\n{model_name.upper()}:")
            print("-" * 40)
            print(f"  True Positives:  {metrics['true_positives']}")
            print(f"  False Positives: {metrics['false_positives']}")
            print(f"  False Negatives: {metrics['false_negatives']}")
            print(f"  True Negatives:  {metrics['true_negatives']}")
            print(f"  Precision:       {metrics['precision']:.2%}")
            print(f"  Recall:          {metrics['recall']:.2%}")
            print(f"  F1-Score:        {metrics['f1_score']:.2%}")
    
    def analyze_invalid_messages(self, test_data, invalid_messages):
        """Detailed analysis of detected invalid messages"""
        if not invalid_messages:
            print("\n✓ No invalid messages to analyze")
            return None
            
        print("\n" + "="*70)
        print("DETAILED INVALID MESSAGE ANALYSIS")
        print("="*70)
        
        test_df = self.preprocess_data(test_data)
        
        print(f"\nFound {len(invalid_messages)} invalid message(s):")
        print("\n" + "-"*100)
        
        # Group by violation type
        violation_types = defaultdict(int)
        for msg in invalid_messages:
            for violation in msg['violations']:
                violation_types[violation] += 1
        
        print("\nViolation Summary:")
        for violation, count in sorted(violation_types.items(), key=lambda x: x[1], reverse=True):
            print(f"  {violation}: {count} occurrences")
        
        # Show detailed analysis for first 5 invalid messages
        print("\n" + "-"*100)
        print("\nDetailed Analysis (first 5 messages):")
        
        for i, invalid_msg in enumerate(invalid_messages[:5]):
            print(f"\nINVALID MESSAGE #{i+1}")
            print("-"*50)
            
            msg = test_df.iloc[invalid_msg['index']]
            
            print(f"  Time:        {msg['Time']:.6f}s")
            print(f"  Session:     {int(msg['AMF_UE_NGAP_ID'])}")
            print(f"  Message Type: {int(msg['Type']) if not pd.isna(msg['Type']) else 'NaN'}", end="")
            
            msg_type_str = str(int(msg['Type'])) if not pd.isna(msg['Type']) else None
            if msg_type_str in self.valid_5gmm_types:
                print(f" ({self.valid_5gmm_types[msg_type_str]})")
            else:
                print(" (UNKNOWN)")
            
            print(f"  Proc Code:   {int(msg['procedureCode']) if not pd.isna(msg['procedureCode']) else 'NaN'}")
            print(f"  Sequence:    {int(msg['Seqn']) if not pd.isna(msg['Seqn']) else 'NaN'}")
            print(f"  SecHdr:      {int(msg['SecHdr']) if not pd.isna(msg['SecHdr']) else 'NaN'}")
            print(f"  EPD:         {int(msg['EPD']) if not pd.isna(msg['EPD']) else 'NaN'}")
            print(f"  Spare:       {int(msg['spare']) if not pd.isna(msg['spare']) else 'NaN'}")
            print(f"  Severity:    {invalid_msg['severity']}")
            
            print(f"\n  Violations ({invalid_msg['violation_count']}):")
            for violation in invalid_msg['violations']:
                print(f"    ⚠️  {violation}")
        
        return invalid_messages
    
    def visualize_anomalies(self, results, invalid_messages):
        """Create visualizations for invalid message detection results.

        Draws a three-panel figure (per-model valid/invalid counts, violation
        type distribution, per-model precision/recall/F1), saves it as
        'invalid_message_detection_analysis.png' in the working directory and
        then displays it.

        Args:
            results: Mapping of model name to a result dict. Each result dict
                is read for 'predictions' (compared element-wise against 0
                and 1, so presumably a NumPy array) and for 'metrics' (a dict
                with 'precision', 'recall' and 'f1_score').
            invalid_messages: Sized collection of dicts, each providing a
                'violations' list of description strings.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        model_names = list(results.keys())

        self._plot_detection_counts(axes[0], results, model_names)
        self._plot_violation_types(axes[1], invalid_messages)
        self._plot_performance_metrics(axes[2], results, model_names)

        # Add summary
        fig.suptitle(
            f'Invalid NAS Message Detection Analysis - {len(invalid_messages)} Invalid Message(s) Found',
            fontsize=16, fontweight='bold')

        plt.tight_layout()
        # Save through the figure object so the output never depends on which
        # figure happens to be "current" in pyplot's global state.
        fig.savefig('invalid_message_detection_analysis.png', dpi=300, bbox_inches='tight')
        print("\n✓ Visualizations saved as 'invalid_message_detection_analysis.png'")
        plt.show()

    @staticmethod
    def _plot_detection_counts(ax, results, model_names):
        """Draw grouped bars of valid (0) vs. invalid (1) predictions per model."""
        normal_counts = [np.sum(result['predictions'] == 0) for result in results.values()]
        invalid_counts = [np.sum(result['predictions'] == 1) for result in results.values()]

        x = np.arange(len(model_names))
        width = 0.35

        ax.bar(x - width/2, normal_counts, width, label='Valid Messages', color='lightgreen')
        ax.bar(x + width/2, invalid_counts, width, label='Invalid Messages', color='lightcoral')

        # Add count labels just above each bar
        for i, (normal, invalid) in enumerate(zip(normal_counts, invalid_counts)):
            ax.text(i - width/2, normal + 0.5, str(normal), ha='center', va='bottom')
            ax.text(i + width/2, invalid + 0.5, str(invalid), ha='center', va='bottom')

        ax.set_xlabel('Models')
        ax.set_ylabel('Number of Messages')
        ax.set_title('Valid vs Invalid Message Detection by Model')
        ax.set_xticks(x)
        ax.set_xticklabels(model_names, rotation=15)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    @staticmethod
    def _plot_violation_types(ax, invalid_messages):
        """Draw a horizontal bar chart of how often each violation category occurs."""
        # (substring searched in the violation text, label shown on the chart).
        # Order matters: the first matching entry wins. Violations matching
        # none of these entries are not counted in the chart.
        categories = (
            ("Invalid message type", "Invalid Message Type"),
            ("Invalid EPD", "Invalid EPD"),
            ("Invalid security header", "Invalid Security Header"),
            ("Invalid sequence number", "Invalid Sequence Number"),
            ("Invalid spare field", "Invalid Spare Field"),
        )

        violation_types = defaultdict(int)
        for msg in invalid_messages:
            for violation in msg['violations']:
                for needle, label in categories:
                    if needle in violation:
                        violation_types[label] += 1
                        break

        ax.set_title('Distribution of Violation Types')

        if not violation_types:
            ax.text(0.5, 0.5, 'No violations detected',
                    ha='center', va='center', transform=ax.transAxes)
            return

        violations = list(violation_types.keys())
        counts = list(violation_types.values())

        y_pos = np.arange(len(violations))
        ax.barh(y_pos, counts, color='orange')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(violations)
        ax.set_xlabel('Count')
        ax.grid(True, alpha=0.3, axis='x')

        # Add count labels to the right of each bar
        for i, count in enumerate(counts):
            ax.text(count + 0.1, i, str(count), va='center')

    @staticmethod
    def _plot_performance_metrics(ax, results, model_names):
        """Draw precision, recall and F1 (as percentages) as markers, one row per model."""
        # (metrics key, legend label, color, marker, vertical offset of the
        # value label). The offsets keep the three labels of one row from
        # being drawn on the same baseline.
        series = (
            ('precision', 'Precision', 'blue', 'o', 0.0),
            ('recall', 'Recall', 'green', 's', 0.1),
            ('f1_score', 'F1-Score', 'red', '^', -0.1),
        )

        for i, result in enumerate(results.values()):
            metrics = result['metrics']
            for key, label, color, marker, y_offset in series:
                value = metrics[key] * 100
                # Only the first row carries legend labels so each metric
                # appears once in the legend.
                ax.scatter(value, i, s=100, label=label if i == 0 else '',
                           color=color, marker=marker)
                ax.text(value + 1, i + y_offset, f"{value:.1f}%",
                        va='center', fontsize=9, color=color)

        ax.set_yticks(np.arange(len(model_names)))
        ax.set_yticklabels(model_names)
        ax.set_xlabel('Performance (%)')
        ax.set_title('Model Performance Metrics')
        ax.set_xlim(-5, 110)
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3, axis='x')
    
    def save_models(self, filepath='nas_invalid_message_models.pkl'):
        """Persist the trained models and the detector's parameters.

        Collects the trained models, the scaler, the feature columns and the
        tunable settings into a single dictionary and writes it to
        ``filepath`` with joblib.

        Args:
            filepath: Destination path of the dump file.
        """
        # Attributes stored under a key identical to the attribute name.
        same_named_attrs = (
            'scaler',
            'feature_columns',
            'ensemble_weights',
            'optimal_threshold',
            'contamination_rate',
            'valid_5gmm_types',
            'confidence_threshold',
            'rule_based_weight',
            'ml_based_weight',
        )

        # The trained models are the one entry whose key ('models') differs
        # from the attribute it comes from.
        payload = {'models': self.trained_models}
        for attr in same_named_attrs:
            payload[attr] = getattr(self, attr)

        joblib.dump(payload, filepath)
        print(f"\n✓ Models saved to {filepath}")
    
    def load_models(self, filepath='nas_invalid_message_models.pkl'):
        """Load trained models and parameters from a joblib dump.

        The entries 'models', 'scaler' and 'feature_columns' are required.
        The remaining settings are optional: when one is absent from the
        file, the value currently set on this instance is kept, so a file
        without those keys cannot silently replace the detector's
        configuration with unrelated hard-coded numbers. The exception is
        'optimal_threshold', which is reset to None when absent.

        Args:
            filepath: Path of the dump file to read.

        Raises:
            KeyError: If 'models', 'scaler' or 'feature_columns' is missing
                from the loaded data.
        """
        # SECURITY NOTE: joblib.load is pickle-based and can execute arbitrary
        # code while loading; only use it on files from a trusted source.
        model_data = joblib.load(filepath)

        self.trained_models = model_data['models']
        self.scaler = model_data['scaler']
        self.feature_columns = model_data['feature_columns']

        # Not carried over from the instance's previous state when the file
        # has no threshold.
        self.optimal_threshold = model_data.get('optimal_threshold', None)

        # Optional settings: fall back to the instance's current value.
        optional_settings = (
            'ensemble_weights',
            'contamination_rate',
            'valid_5gmm_types',
            'confidence_threshold',
            'rule_based_weight',
            'ml_based_weight',
        )
        for name in optional_settings:
            setattr(self, name, model_data.get(name, getattr(self, name)))

        print(f"\n✓ Models loaded from {filepath}")

# Main execution
if __name__ == "__main__":
    # End-to-end demo: train on benign traffic, persist the models, evaluate
    # on a test capture, report and plot the findings, then reload the models
    # into a fresh detector and check that it predicts the same labels.

    # Initialize detector
    detector = EnhancedNASInvalidMessageDetector()

    # Train on clean data (semicolon-separated CSV, every column read as text)
    print("Loading training data...")
    train_data = pd.read_csv("BenignData.csv", dtype=str, sep=";")
    print(f"Training data shape: {train_data.shape}")

    print("\nStarting training for invalid message detection...")
    detector.train(train_data, optimize_params=True)

    # Save models
    detector.save_models('nas_invalid_message_models.pkl')

    # Test on potentially contaminated data
    print("\nLoading test data...")
    test_data = pd.read_csv("InvalidheaderTest.csv", dtype=str, sep=";")
    print(f"Test data shape: {test_data.shape}")

    results, invalid_messages = detector.test(test_data, use_ensemble=True)

    # Analyze invalid messages
    detector.analyze_invalid_messages(test_data, invalid_messages)

    # Print metrics
    detector.print_metrics(results)

    # Create visualizations
    detector.visualize_anomalies(results, invalid_messages)

    # Demonstrate model loading
    print("\n" + "="*70)
    print("DEMONSTRATING MODEL LOAD")
    print("="*70)

    new_detector = EnhancedNASInvalidMessageDetector()
    new_detector.load_models('nas_invalid_message_models.pkl')

    # Test with loaded models
    results2, _ = new_detector.test(test_data, use_ensemble=True)

    # Actually compare the two runs instead of assuming they agree: a model
    # counts as mismatched when it is missing from the reloaded results or
    # its predictions differ from the original detector's.
    mismatched_models = [
        model_name
        for model_name, result in results.items()
        if model_name not in results2
        or not np.array_equal(result['predictions'], results2[model_name]['predictions'])
    ]
    if mismatched_models:
        print(f"\n✗ Loaded models produce different results for: {', '.join(mismatched_models)}")
    else:
        print("\n✓ Loaded models produce consistent results!")

    print("\n" + "="*70)
    print("INVALID MESSAGE DETECTION COMPLETE")
    print("="*70)