"""
SentinelAgent ML Model Training - Google Colab Notebook
========================================================

This notebook trains a threat detection model and exports it for use in production.

Run each cell sequentially by clicking the play button or pressing Shift+Enter.
"""

# =============================================================================
# CELL 1: Install Required Packages
# =============================================================================
# Install through the running interpreter's pip rather than the IPython-only
# `!pip` shell escape: the `!` line is a SyntaxError when this file is run as
# plain Python, and it never raises, so the success message below would be
# printed even if the install failed. check_call raises CalledProcessError on
# a non-zero pip exit status instead.
import subprocess
import sys

print(" Installing required packages...")
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "scikit-learn", "pandas", "numpy", "matplotlib", "seaborn", "joblib",
])

print(" Packages installed successfully!")

# =============================================================================
# CELL 2: Import Libraries
# =============================================================================
# Third-party and stdlib imports shared by the remaining cells (os and shutil
# are imported locally in Cells 9 and 11), plus global plotting configuration.
print("Importing libraries...")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
# NOTE(review): LabelEncoder, roc_curve, auc and roc_auc_score are not used by
# any cell of this notebook as written; kept in case ROC analysis is added.
from sklearn.metrics import (classification_report, confusion_matrix, 
                            accuracy_score, precision_recall_fscore_support,
                            roc_curve, auc, roc_auc_score)
from sklearn.tree import DecisionTreeClassifier
import joblib
import json
from datetime import datetime
import warnings
# NOTE(review): this ignores *every* warning for the rest of the session,
# including scikit-learn metric/convergence warnings; consider restricting it
# to specific categories so real problems are not hidden.
warnings.filterwarnings('ignore')

# Set style for better visualizations
sns.set_style("whitegrid")
# Default size for figures created without an explicit figsize
plt.rcParams['figure.figsize'] = (12, 6)

print(" Libraries imported successfully!")

# =============================================================================
# CELL 3: Generate Training Data
# =============================================================================
print(" Generating training data...")

def generate_threat_data(n_samples=5000, random_state=42):
    """
    Generate synthetic threat detection training data.

    Each row is one synthetic network event with ten numeric features. A
    rule-based risk score (failed attempts, payload entropy, request
    frequency, geo risk, unusual port, night-time activity, short bursty
    connections) is mapped to a label. Then int(0.05 * n_samples) randomly
    chosen rows are reassigned a label drawn uniformly from {0, 1, 2} as
    noise.

    In production, replace this with your actual log data from:
    - Network traffic logs
    - System logs
    - Security events

    Args:
        n_samples: Number of rows to generate.
        random_state: Seed for a private ``np.random.RandomState``. The data
            is identical to what seeding NumPy's global RNG with the same
            value would produce, but the global RNG is left untouched.

    Returns:
        pandas.DataFrame with the ten feature columns plus ``threat_level``
        (0 = Benign, 1 = Suspicious, 2 = Malicious).
    """
    # A private legacy RandomState draws exactly the same stream as
    # np.random.seed(random_state) followed by np.random.* calls, without
    # silently reseeding the global generator for every other caller.
    rng = np.random.RandomState(random_state)

    # Generate features
    data = {
        'packet_size': rng.randint(64, 1500, n_samples),
        'request_frequency': rng.randint(1, 150, n_samples),
        'port_number': rng.choice([21, 22, 23, 80, 443, 3389, 8080, 8443], n_samples),
        'failed_attempts': rng.randint(0, 15, n_samples),
        'connection_duration': rng.randint(1, 7200, n_samples),
        'payload_entropy': rng.uniform(0, 8, n_samples),
        'is_encrypted': rng.choice([0, 1], n_samples, p=[0.3, 0.7]),
        'geo_risk_score': rng.uniform(0, 10, n_samples),
        'unusual_port': rng.choice([0, 1], n_samples, p=[0.85, 0.15]),
        'time_of_day': rng.randint(0, 24, n_samples),  # 0-23 hours
    }

    df = pd.DataFrame(data)

    failed = df['failed_attempts'].to_numpy()
    entropy = df['payload_entropy'].to_numpy()
    frequency = df['request_frequency'].to_numpy()
    geo_risk = df['geo_risk_score'].to_numpy()
    hour = df['time_of_day'].to_numpy()
    duration = df['connection_duration'].to_numpy()

    # Vectorized risk score. np.select takes the first matching condition, so
    # each pair below mirrors an `if / elif` tier (higher threshold wins).
    score = (
        np.select([failed > 8, failed > 4], [3, 2], default=0)
        + np.select([entropy > 7.5, entropy > 6.5], [3, 2], default=0)
        + np.select([frequency > 100, frequency > 60], [3, 2], default=0)
        + np.select([geo_risk > 8, geo_risk > 6], [2, 1], default=0)
        + np.where(df['unusual_port'].to_numpy() == 1, 2, 0)
        # Night time activity (suspicious): 01:00-05:59
        + np.where((hour >= 1) & (hour <= 5), 1, 0)
        # Short connections with high frequency
        + np.where((duration < 30) & (frequency > 50), 2, 0)
    )

    # Classification: >= 8 Malicious, >= 4 Suspicious, otherwise Benign
    df['threat_level'] = np.select(
        [score >= 8, score >= 4], [2, 1], default=0
    ).astype(np.int64)

    # Add some noise to make it more realistic. Kept as a per-row loop so the
    # random draws happen in the same order as the original implementation.
    noise_indices = rng.choice(n_samples, size=int(n_samples * 0.05), replace=False)
    for idx in noise_indices:
        df.at[idx, 'threat_level'] = rng.choice([0, 1, 2])

    return df

# Build the dataset used by every later cell (5,000 rows, default random_state)
df = generate_threat_data(n_samples=5000)
labels = df['threat_level']

print(f" Generated {len(df)} samples")
print(f"\n Dataset Shape: {df.shape}")
print("\n Target Distribution:")
print(labels.value_counts().sort_index())

# One line per class; only the first line is preceded by a blank line
for level, level_name in enumerate(('Benign', 'Suspicious', 'Malicious')):
    prefix = "\n" if level == 0 else ""
    print(f"{prefix}   {level} = {level_name}: {(labels == level).sum()}")

# Preview: as the last expression of a notebook cell, this renders as a table
print("\n Sample Data:")
df.head(10)

# =============================================================================
# CELL 4: Exploratory Data Analysis (EDA)
# =============================================================================
print(" Performing Exploratory Data Analysis...")

class_labels = ['Benign', 'Suspicious', 'Malicious']

# 2x2 overview: class balance plus three per-class feature distributions
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top-left panel: number of samples per threat level
threat_counts = df['threat_level'].value_counts().sort_index()
class_ax = axes[0, 0]
class_ax.bar(class_labels, threat_counts.values,
             color=['green', 'orange', 'red'], alpha=0.7)
class_ax.set_title('Threat Level Distribution', fontsize=14, fontweight='bold')
class_ax.set_ylabel('Count')
class_ax.grid(axis='y', alpha=0.3)

# Remaining panels: one grouped box plot per feature. The grouped boxes sit at
# x = 1, 2, 3, which is why those tick positions are relabelled below.
boxplot_panels = [
    (axes[0, 1], 'failed_attempts', 'Failed Attempts'),
    (axes[1, 0], 'payload_entropy', 'Payload Entropy'),
    (axes[1, 1], 'request_frequency', 'Request Frequency'),
]
for panel_ax, column, pretty_name in boxplot_panels:
    df.boxplot(column=column, by='threat_level', ax=panel_ax)
    panel_ax.set_title(f'{pretty_name} by Threat Level', fontsize=14, fontweight='bold')
    panel_ax.set_xlabel('Threat Level')
    panel_ax.set_ylabel(pretty_name)
    plt.sca(panel_ax)
    plt.xticks([1, 2, 3], class_labels)

plt.tight_layout()
plt.show()

# Pairwise correlations of all columns, including the label
print("\n Feature Correlation Heatmap:")
plt.figure(figsize=(12, 8))
correlation_matrix = df.corr()
sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm',
            center=0, square=True, linewidths=1)
plt.title('Feature Correlation Matrix', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# =============================================================================
# CELL 5: Prepare Data for Training
# =============================================================================
print(" Preparing data for training...")

# Every column except the label is a model input
target_column = 'threat_level'
X = df.drop(columns=[target_column])
y = df[target_column]

feature_names = list(X.columns)
print(f"Features: {feature_names}")

# Stratified 80/20 split keeps the class proportions equal in both halves
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("\n Data Split:")
print(f"   Training set: {len(X_train)} samples")
print(f"   Test set: {len(X_test)} samples")

# Learn mean/std from the training split only, then apply them to both splits
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

print("\n Data preparation complete!")

# =============================================================================
# CELL 6: Train Multiple Models and Compare
# =============================================================================
print(" Training multiple models for comparison...")

models = {
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, random_state=42),
    'Decision Tree': DecisionTreeClassifier(max_depth=10, random_state=42)
}

results = {}

for name, model in models.items():
    print(f"\n{'='*60}")
    print(f"Training {name}...")

    # Model selection score: 3-fold CV on the TRAINING split only (same folds
    # and metric as the GridSearchCV in Cell 8). Choosing the winner by
    # test-set accuracy would leak the test split into model selection and
    # make every test metric reported afterwards optimistically biased.
    cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=3, scoring='accuracy')

    # Train on the full training split
    model.fit(X_train_scaled, y_train)

    # Predict on the held-out test split (reporting only, not selection)
    y_pred = model.predict(X_test_scaled)

    # Metrics
    accuracy = accuracy_score(y_test, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')

    results[name] = {
        'model': model,
        'cv_accuracy': cv_scores.mean(),
        'cv_accuracy_std': cv_scores.std(),
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'predictions': y_pred
    }

    print(f" {name} trained successfully!")
    print(f"   CV Accuracy (train): {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
    print(f"   Accuracy: {accuracy:.4f}")
    print(f"   Precision: {precision:.4f}")
    print(f"   Recall: {recall:.4f}")
    print(f"   F1-Score: {f1:.4f}")

# Compare models
print(f"\n{'='*60}")
print(" MODEL COMPARISON")
print(f"{'='*60}")

comparison_df = pd.DataFrame({
    'Model': list(results.keys()),
    'CV Accuracy': [r['cv_accuracy'] for r in results.values()],
    'Accuracy': [r['accuracy'] for r in results.values()],
    'Precision': [r['precision'] for r in results.values()],
    'Recall': [r['recall'] for r in results.values()],
    'F1-Score': [r['f1'] for r in results.values()]
})

print(comparison_df.to_string(index=False))

# Find best model by cross-validated training accuracy, never by test accuracy
best_model_name = max(results, key=lambda model_name: results[model_name]['cv_accuracy'])
best_model = results[best_model_name]['model']

print(f"\n Best Model: {best_model_name} (selected by cross-validated training accuracy)")

# =============================================================================
# CELL 7: Detailed Evaluation of Best Model
# =============================================================================
section_rule = '=' * 60
print(f"\n{section_rule}")
print(f"DETAILED EVALUATION: {best_model_name}")
print(section_rule)

class_names = ['Benign', 'Suspicious', 'Malicious']

# Reuse the test-set predictions stored for this model in Cell 6
y_pred_best = results[best_model_name]['predictions']

# Per-class precision / recall / F1 on the test split
print("\n Classification Report:")
print(classification_report(y_test, y_pred_best, target_names=class_names))

cm = confusion_matrix(y_test, y_pred_best)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
cm_ax, importance_ax = axes

# Left panel: confusion matrix (rows = true label, columns = predicted label)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=cm_ax,
            xticklabels=class_names, yticklabels=class_names)
cm_ax.set_title(f'Confusion Matrix - {best_model_name}', fontsize=14, fontweight='bold')
cm_ax.set_ylabel('True Label')
cm_ax.set_xlabel('Predicted Label')

# Right panel: the model's feature_importances_, left empty if it has none
if hasattr(best_model, 'feature_importances_'):
    feature_importance = (
        pd.DataFrame({'feature': feature_names,
                      'importance': best_model.feature_importances_})
        .sort_values('importance', ascending=True)
    )
    importance_ax.barh(feature_importance['feature'], feature_importance['importance'],
                       color='steelblue', alpha=0.7)
    importance_ax.set_title('Feature Importance', fontsize=14, fontweight='bold')
    importance_ax.set_xlabel('Importance')
    importance_ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

# =============================================================================
# CELL 8: Hyperparameter Tuning (Optional but Recommended)
# =============================================================================
# Status line previously contained mis-decoded UTF-8 ("üîß"); use the plain
# leading-space style of the other status messages.
print("\n Performing hyperparameter tuning...")

# Only the Random Forest has a search grid; any other winner is used as-is.
if best_model_name == 'Random Forest':
    param_grid = {
        'n_estimators': [100, 200],
        'max_depth': [10, 15, 20],
        'min_samples_split': [2, 5],
        'min_samples_leaf': [1, 2]
    }

    # 3-fold CV on the training split; the test split is not touched here.
    grid_search = GridSearchCV(
        RandomForestClassifier(random_state=42, n_jobs=-1),
        param_grid,
        cv=3,
        scoring='accuracy',
        n_jobs=-1,
        verbose=1
    )

    print(" Training with GridSearchCV (this may take a few minutes)...")
    grid_search.fit(X_train_scaled, y_train)

    print(f"\n Best Parameters: {grid_search.best_params_}")
    print(f" Best CV Score: {grid_search.best_score_:.4f}")

    # Use best model (refit on the full training split by GridSearchCV)
    final_model = grid_search.best_estimator_

    # Evaluate tuned model on the held-out test split
    y_pred_tuned = final_model.predict(X_test_scaled)
    accuracy_tuned = accuracy_score(y_test, y_pred_tuned)

    print(f"\n Tuned Model Test Accuracy: {accuracy_tuned:.4f}")
    print(f"   Improvement: {(accuracy_tuned - results[best_model_name]['accuracy']):.4f}")
else:
    final_model = best_model
    print(f"Using {best_model_name} without additional tuning")

# =============================================================================
# CELL 9: Save Model and Artifacts
# =============================================================================
print("\n Saving model and artifacts...")

import os

# All artifacts live in one directory, which Cell 11 zips for download
model_dir = 'sentinel_model'
os.makedirs(model_dir, exist_ok=True)

model_path = f'{model_dir}/threat_model.pkl'
scaler_path = f'{model_dir}/scaler.pkl'
metadata_path = f'{model_dir}/metadata.json'

# The fitted estimator and the scaler whose output it was trained on
joblib.dump(final_model, model_path)
print(f" Model saved to: {model_path}")

joblib.dump(scaler, scaler_path)
print(f" Scaler saved to: {scaler_path}")

final_test_accuracy = accuracy_score(y_test, final_model.predict(X_test_scaled))

metadata = {
    'model_type': best_model_name,
    'feature_names': feature_names,
    'trained_date': datetime.now().isoformat(),
    'accuracy': float(final_test_accuracy),
    'n_samples': len(df),
    'train_test_split': 0.2,
    # JSON object keys are always strings: these are written as "0", "1", "2"
    'threat_levels': {0: 'Benign', 1: 'Suspicious', 2: 'Malicious'},
}

with open(metadata_path, 'w') as f:
    f.write(json.dumps(metadata, indent=2))
print(f" Metadata saved to: {metadata_path}")

print(f"\n All files saved in '{model_dir}/' directory")

# =============================================================================
# CELL 10: Test the Saved Model
# =============================================================================
print("\n Testing the saved model...")

# Reload every artifact from disk, as a production consumer would
loaded_model = joblib.load(model_path)
loaded_scaler = joblib.load(scaler_path)
with open(metadata_path) as f:
    loaded_metadata = json.load(f)

# Column order must match training. Take it from the saved metadata instead of
# relying on the key order of the hand-written sample dicts below; this also
# checks that metadata.json round-trips correctly.
feature_order = loaded_metadata['feature_names']
threat_labels = ['Benign', 'Suspicious', 'Malicious']

# Test data
test_samples = [
    {
        'packet_size': 1200,
        'request_frequency': 95,
        'port_number': 80,
        'failed_attempts': 8,
        'connection_duration': 250,
        'payload_entropy': 7.8,
        'is_encrypted': 0,
        'geo_risk_score': 8.5,
        'unusual_port': 0,
        'time_of_day': 3
    },
    {
        'packet_size': 850,
        'request_frequency': 12,
        'port_number': 443,
        'failed_attempts': 1,
        'connection_duration': 1200,
        'payload_entropy': 4.2,
        'is_encrypted': 1,
        'geo_risk_score': 2.1,
        'unusual_port': 0,
        'time_of_day': 14
    }
]

print("\n Testing with sample data:\n")

for i, sample in enumerate(test_samples, 1):
    # Prepare data: same columns, same order, same scaling as training
    sample_df = pd.DataFrame([sample])[feature_order]
    sample_scaled = loaded_scaler.transform(sample_df)

    # Predict. Indexing probabilities by the predicted label assumes the
    # model's classes_ are exactly [0, 1, 2], as produced by Cell 3's labels.
    prediction = loaded_model.predict(sample_scaled)[0]
    probabilities = loaded_model.predict_proba(sample_scaled)[0]

    print(f"Sample {i}:")
    print(f"  Failed Attempts: {sample['failed_attempts']}")
    print(f"  Request Frequency: {sample['request_frequency']}")
    print(f"  Payload Entropy: {sample['payload_entropy']:.2f}")
    print(f"  → Prediction: {threat_labels[prediction]} (confidence: {probabilities[prediction]:.2%})")
    print(f"  → Probabilities: Benign={probabilities[0]:.2%}, "
          f"Suspicious={probabilities[1]:.2%}, Malicious={probabilities[2]:.2%}\n")

# =============================================================================
# CELL 11: Download Model Files
# =============================================================================
print("\n Preparing files for download...")

# Zip the model directory. root_dir=model_dir puts the three files at the top
# level of sentinel_model.zip rather than inside a nested folder.
import shutil
shutil.make_archive('sentinel_model', 'zip', model_dir)

print("\n Model package created: sentinel_model.zip")
print("\n Package contents:")
print("   - threat_model.pkl (trained model)")
print("   - scaler.pkl (feature scaler)")
print("   - metadata.json (model information)")

print("\n" + "="*60)
print(" TRAINING COMPLETE!")
print("="*60)
print("\n Download 'sentinel_model.zip' from the Files panel (left sidebar)")
print(" Extract the files and place them in your project's 'models/' directory")
print("\n Ready to use in production!")

# Download instructions (previously contained mis-decoded UTF-8 for the
# folder icon and the arrow)
print("\n To download:")
print("   1. Click the folder icon (📁) in the left sidebar")
print("   2. Find 'sentinel_model.zip'")
print("   3. Right-click → Download")
print("   4. Extract and use in your SentinelAgent project")