# In [None]:
from flask import Flask, render_template, jsonify, request
import numpy as np
import joblib
import pickle
from tensorflow.keras.models import load_model
from datetime import datetime, timedelta
import random
import time
import threading
import warnings
import os
import absl.logging
from scapy.all import sniff, Ether, IP, TCP, UDP, ICMP
from collections import defaultdict

# absl logging (used by TensorFlow): only surface errors.
absl.logging.set_verbosity(absl.logging.ERROR)
# Suppress TensorFlow warnings.
# NOTE: this filter is process-wide, so it hides UserWarnings raised by every
# other library in this process as well, not just TensorFlow.
warnings.filterwarnings('ignore', category=UserWarning)

app = Flask(__name__)

# Ensure the models directory exists. exist_ok=True avoids the race between a
# separate os.path.exists() check and the directory creation.
os.makedirs('models', exist_ok=True)

# Registry of trained artefacts, keyed by role. Every slot starts as None and
# is populated by load_models(); the status views treat a None slot as an
# inactive model.
models = dict.fromkeys((
    'scaler',
    'autoencoder',
    'autoencoder_threshold',
    'kmeans',
    'kmeans_threshold',
    'isolation_forest',
    'feature_cols',
))

def load_models():
    """Load the trained models and detection thresholds from ``models/``.

    Populates the module-level ``models`` registry in place:

    * ``feature_cols``: the fixed list of feature names (always set);
    * ``scaler``, ``kmeans``, ``isolation_forest``: via ``joblib.load``;
    * ``autoencoder``: via Keras ``load_model`` from an ``.h5`` file;
    * ``autoencoder_threshold``, ``kmeans_threshold``: via ``pickle.load``.

    Loading stops at the first failure. The error is printed, not raised, so
    the app keeps running. Artefacts loaded before the failure stay in
    ``models``; the remaining slots keep their previous value (``None`` at
    start-up).
    """
    # These should match the features used during training. They are set
    # before any file I/O so they are available even when loading fails.
    # Previously the same list was duplicated in the success and error paths.
    models['feature_cols'] = [
        'packet_size', 'protocol', 'src_port', 'dst_port',
        'tcp_fin', 'tcp_syn', 'tcp_rst', 'tcp_psh', 'tcp_ack', 'tcp_urg'
    ]
    try:
        print("Loading models...")
        models['scaler'] = joblib.load('models/scaler.pkl')
        models['autoencoder'] = load_model('models/autoencoder_model.h5')
        models['kmeans'] = joblib.load('models/kmeans_model.pkl')
        models['isolation_forest'] = joblib.load('models/isolation_forest_model.pkl')

        # SECURITY: pickle/joblib can execute arbitrary code while loading;
        # only deploy model files from a trusted source.
        with open('models/autoencoder_threshold.pkl', 'rb') as f:
            models['autoencoder_threshold'] = pickle.load(f)

        with open('models/kmeans_threshold.pkl', 'rb') as f:
            models['kmeans_threshold'] = pickle.load(f)

        print("All models loaded successfully")
    except Exception as e:
        # Best effort: report and continue without the missing models.
        print(f"Error loading models: {str(e)}")

# Load models when app starts
load_models()

# Shared, mutable state for live packet capture. The background capture and
# stats threads write it; the Flask views and JSON endpoints read it.
traffic_data = dict(
    timestamps=[],      # '%H:%M' labels for the traffic-over-time chart
    packet_counts=[],   # one value per label in timestamps
    protocol_distribution=dict.fromkeys(
        ('TCP', 'UDP', 'ICMP', 'HTTP', 'HTTPS', 'Other'), 0
    ),
    anomalies=[],       # newest detected anomalies
    alerts=[],          # newest alerts
    packet_buffer=[],   # newest per-packet feature dicts
    stats=dict(total_packets=0, start_time=datetime.now()),
)

# Packet processing function
def process_packet(packet):
    """Extract per-packet features, update counters and trigger analysis.

    Used as scapy's ``sniff(prn=...)`` callback, so it runs once per captured
    packet on the capture thread. It does the following:

    * Builds a ``packet_info`` dict: capture time, length, protocol label,
      ports and TCP flag bits as 0/1.
    * Increments the protocol distribution and the total packet counter.
    * Appends the dict to ``traffic_data['packet_buffer']``.
    * Calls ``analyze_packets()`` on every 10th packet.

    Errors are printed rather than raised, so a single bad packet cannot stop
    the capture loop.

    Protocol labels:

    * TCP with 443 as source or destination port: ``'HTTPS'``.
    * Otherwise, TCP with 80 as either port: ``'HTTP'``.
    * Any other TCP: ``'TCP'``.
    * UDP and ICMP: ``'UDP'`` / ``'ICMP'``.
    * Anything else, including packets without an IP layer: ``'Other'``.
    """
    try:
        packet_info = {
            'timestamp': datetime.now().strftime('%H:%M:%S'),
            'size': len(packet),
            'protocol': 'Other',
            'src_port': None,
            'dst_port': None,
            'tcp_flags': {
                'fin': 0, 'syn': 0, 'rst': 0,
                'psh': 0, 'ack': 0, 'urg': 0
            }
        }

        if packet.haslayer(IP):
            if packet.haslayer(TCP):
                tcp = packet[TCP]
                ports = (tcp.sport, tcp.dport)
                # 443 is checked first, so a 80<->443 connection counts as HTTPS.
                if 443 in ports:
                    packet_info['protocol'] = 'HTTPS'
                elif 80 in ports:
                    packet_info['protocol'] = 'HTTP'
                else:
                    packet_info['protocol'] = 'TCP'
                packet_info['src_port'] = tcp.sport
                packet_info['dst_port'] = tcp.dport
                flags = tcp.flags
                # Normalise scapy's per-flag attributes to 0/1 ints so every
                # packet_info carries the same value type as the defaults above.
                packet_info['tcp_flags'] = {
                    'fin': int(flags.F), 'syn': int(flags.S), 'rst': int(flags.R),
                    'psh': int(flags.P), 'ack': int(flags.A), 'urg': int(flags.U),
                }
            elif packet.haslayer(UDP):
                udp = packet[UDP]
                packet_info['protocol'] = 'UDP'
                packet_info['src_port'] = udp.sport
                packet_info['dst_port'] = udp.dport
            elif packet.haslayer(ICMP):
                packet_info['protocol'] = 'ICMP'

        traffic_data['protocol_distribution'][packet_info['protocol']] += 1
        traffic_data['packet_buffer'].append(packet_info)
        traffic_data['stats']['total_packets'] += 1

        # Analyze every 10th packet. Using the monotonic total rather than
        # len(packet_buffer) keeps the cadence independent of how
        # analyze_packets trims the buffer.
        if traffic_data['stats']['total_packets'] % 10 == 0:
            analyze_packets()

    except Exception as e:
        print(f"Error processing packet: {str(e)}")

def analyze_packets():
    """Run the (simulated) anomaly check, then cap the stored history.

    Does nothing if ``traffic_data['packet_buffer']`` is empty. Otherwise it
    performs these steps:

    1. With probability 0.05, appends a synthetic anomaly to
       ``traffic_data['anomalies']``. The anomaly carries the newest buffered
       packet's timestamp, a random attack type, a placeholder
       ``192.168.x.y`` source and a score in [0.7, 1.0] rounded to 2 places.
    2. If that anomaly's score exceeds 0.9, also appends a ``'high'``
       severity alert.
    3. Keeps only the 20 newest anomalies, the 20 newest alerts and the 100
       newest buffered packets.

    Errors are printed, not raised.

    NOTE(review): the models loaded into ``models`` are not consulted here;
    this is demo logic awaiting real model predictions.
    """
    try:
        buffer = traffic_data['packet_buffer']
        if not buffer:
            return

        newest = buffer[-1]

        # Simulate anomaly detection (replace with actual model predictions)
        roll = random.random()
        if roll < 0.05:  # 5% chance of anomaly for demo
            candidates = ['Port Scan', 'DDoS', 'Malware', 'Brute Force', 'Data Exfiltration']
            kind = random.choice(candidates)
            octet3, octet4 = random.randint(1, 255), random.randint(1, 255)
            source = f"192.168.{octet3}.{octet4}"  # Placeholder
            score = round(random.uniform(0.7, 1.0), 2)
            traffic_data['anomalies'].append({
                'timestamp': newest['timestamp'],
                'anomaly_type': kind,
                'src_ip': source,
                'anomaly_score': score,
            })

            # Only high-scoring anomalies raise an alert.
            if score > 0.9:
                traffic_data['alerts'].append({
                    'timestamp': newest['timestamp'],
                    'severity': 'high',
                    'message': f"{kind} detected from {source}",
                })

        # Cap each history list to its most recent entries.
        for key, keep in (('anomalies', 20), ('alerts', 20), ('packet_buffer', 100)):
            traffic_data[key] = traffic_data[key][-keep:]

    except Exception as e:
        print(f"Error analyzing packets: {str(e)}")

def update_traffic_stats(interval=5, window_minutes=30):
    """Keep the traffic-over-time chart series in ``traffic_data`` current.

    Runs forever and is meant to be a daemon thread target. Every
    ``interval`` seconds it does the following:

    1. Samples ``traffic_data['stats']['total_packets']``.
    2. Credits the packets counted since the previous sample to the current
       wall-clock minute.
    3. Rebuilds ``traffic_data['timestamps']``: ``window_minutes + 1``
       ``'%H:%M'`` labels, oldest first, ending at the current minute.
    4. Rebuilds ``traffic_data['packet_counts']``: the packets observed in
       each of those minutes.

    Previously the counts were ``random.randint`` noise unrelated to the
    captured traffic. Errors are printed; the loop then retries after one
    second.

    Args:
        interval: Seconds between refreshes.
        window_minutes: Minutes of history shown before the current minute.
    """
    per_minute = defaultdict(int)
    last_total = 0
    while True:
        try:
            now = datetime.now()
            current_minute = now.replace(second=0, microsecond=0)

            # All packets seen since the last sample are credited to the
            # current minute. A sample that straddles a minute boundary is
            # therefore attributed to the later minute.
            total = traffic_data['stats']['total_packets']
            per_minute[current_minute] += total - last_total
            last_total = total

            minutes = [current_minute - timedelta(minutes=i)
                       for i in range(window_minutes, -1, -1)]

            # Drop buckets outside the window (including any "future" ones
            # left by a backwards clock jump) so memory stays bounded.
            oldest = minutes[0]
            for stale in [m for m in per_minute if m < oldest or m > current_minute]:
                del per_minute[stale]

            traffic_data['timestamps'] = [m.strftime('%H:%M') for m in minutes]
            traffic_data['packet_counts'] = [per_minute.get(m, 0) for m in minutes]

            time.sleep(interval)
        except Exception as e:
            print(f"Error updating traffic stats: {str(e)}")
            time.sleep(1)

def start_packet_capture():
    """Sniff IP traffic and hand every packet to ``process_packet``.

    Intended as a thread target: scapy's ``sniff`` runs with the BPF filter
    ``"ip"`` and with packet storage disabled (``store=0``). Any exception is
    printed and the function returns, which ends capture.
    """
    try:
        print("Starting packet capture...")
        # The BPF filter can be adjusted as needed.
        sniff(filter="ip", store=0, prn=process_packet)
    except Exception as exc:
        print(f"Error in packet capture: {str(exc)}")

# Launch the background workers at import time: first the chart-stats
# refresher, then packet capture. daemon=True means neither thread keeps the
# process alive once the main thread exits.
traffic_thread = threading.Thread(target=update_traffic_stats, daemon=True)
traffic_thread.start()

capture_thread = threading.Thread(target=start_packet_capture, daemon=True)
capture_thread.start()

# Route definitions: HTML page views
@app.route('/')
def dashboard():
    """Render the main dashboard (``index.html``).

    The template receives the following:

    * the total packet count;
    * the number of stored anomalies;
    * per-severity alert counts (``high``, ``medium``, ``low``);
    * static display settings;
    * the five newest anomalies and alerts, newest first.
    """
    # Copy the shared lists once. The capture thread appends to and replaces
    # them concurrently. Without a snapshot, the counts and the "recent"
    # slices below could come from different states.
    alerts = list(traffic_data['alerts'])
    anomalies = list(traffic_data['anomalies'])

    settings = {
        'preferred_model': 'kmeans',
        'system_status': 'active',
        'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    }

    # Single pass over the alerts. Severities outside these three keys are
    # not counted.
    alert_counts = {'high': 0, 'medium': 0, 'low': 0}
    for alert in alerts:
        if alert['severity'] in alert_counts:
            alert_counts[alert['severity']] += 1

    return render_template('index.html',
                           total_packets=traffic_data['stats']['total_packets'],
                           total_anomalies=len(anomalies),
                           alert_counts=alert_counts,
                           settings=settings,
                           recent_anomalies=anomalies[-5:][::-1],
                           recent_alerts=alerts[-5:][::-1])

@app.route('/traffic_monitor')
def traffic_monitor():
    """Render ``traffic.html`` with the whole live ``traffic_data`` state."""
    context = {
        'traffic_stats': traffic_data,
        'active_page': 'traffic_monitor',
    }
    return render_template('traffic.html', **context)

@app.route('/anomalies')
def anomalies():
    """Render ``anomalies.html`` listing the stored anomalies, newest first."""
    newest_first = list(reversed(traffic_data['anomalies']))
    return render_template('anomalies.html', anomalies=newest_first, active_page='anomalies')

@app.route('/alerts')
def alerts():
    """Render ``alerts.html`` listing the stored alerts, newest first."""
    newest_first = list(reversed(traffic_data['alerts']))
    return render_template('alerts.html', alerts=newest_first, active_page='alerts')

@app.route('/model_management')
def model_management():
    """Render ``models.html`` with a status card per anomaly-detection model.

    A model is ``'active'`` when its slot in ``models`` holds a loaded object,
    and ``'inactive'`` when the slot is still ``None``.
    """
    def _status(key):
        # Compare against None explicitly. Model objects have arbitrary
        # truthiness: an ensemble may define __len__, and array-like objects
        # raise on bool().
        return 'active' if models[key] is not None else 'inactive'

    # NOTE: 'last_trained' and 'performance' are hard-coded display values;
    # they are not read from the loaded model files.
    model_info = {
        'autoencoder': {
            'status': _status('autoencoder'),
            'last_trained': '2023-11-15',
            'performance': {'accuracy': 0.92, 'precision': 0.88, 'recall': 0.90}
        },
        'kmeans': {
            'status': _status('kmeans'),
            'last_trained': '2023-11-15',
            'performance': {'accuracy': 0.85, 'precision': 0.82, 'recall': 0.83}
        },
        'isolation_forest': {
            'status': _status('isolation_forest'),
            'last_trained': '2023-11-15',
            'performance': {'accuracy': 0.89, 'precision': 0.87, 'recall': 0.88}
        }
    }
    return render_template('models.html',
                           model_info=model_info,
                           active_page='model_management')

@app.route('/system_settings')
def system_settings():
    """Render ``settings.html`` with a fixed set of display settings.

    Every value is a literal hard-coded here.
    """
    settings = dict(
        interface='eth0',
        capture_mode='promiscuous',
        alert_threshold='high',
        data_retention='30 days',
        email_alerts='enabled',
        email_address='admin@example.com',
    )
    return render_template('settings.html', settings=settings, active_page='system_settings')

# API endpoints
@app.route('/api/traffic/stats')
def traffic_stats():
    """Return the chart series and protocol counts as JSON.

    Response shape::

        {"traffic_over_time": {"labels": [...], "values": [...]},
         "protocol_distribution": {protocol_name: count, ...}}
    """
    over_time = dict(labels=traffic_data['timestamps'],
                     values=traffic_data['packet_counts'])
    payload = dict(traffic_over_time=over_time,
                   protocol_distribution=traffic_data['protocol_distribution'])
    return jsonify(payload)

@app.route('/api/anomalies/recent')
def recent_anomalies():
    """Return up to ``count`` of the most recent anomalies as JSON, newest first.

    Query parameters:
        count: Number of entries to return (default 5). A non-integer value
            falls back to the default; a negative value is treated as 0.
    """
    # With type=int, Werkzeug returns the default instead of raising
    # ValueError. Previously a bad value surfaced as an HTTP 500.
    count = max(0, request.args.get('count', 5, type=int))
    items = traffic_data['anomalies']
    # items[-0:] is the WHOLE list, so count == 0 needs its own branch.
    newest = items[-count:] if count else []
    return jsonify(newest[::-1])

@app.route('/api/alerts/recent')
def recent_alerts():
    """Return up to ``count`` of the most recent alerts as JSON, newest first.

    Query parameters:
        count: Number of entries to return (default 5). A non-integer value
            falls back to the default; a negative value is treated as 0.
    """
    # With type=int, Werkzeug returns the default instead of raising
    # ValueError. Previously a bad value surfaced as an HTTP 500.
    count = max(0, request.args.get('count', 5, type=int))
    items = traffic_data['alerts']
    # items[-0:] is the WHOLE list, so count == 0 needs its own branch.
    newest = items[-count:] if count else []
    return jsonify(newest[::-1])

@app.route('/api/system/status')
def system_status():
    """Return the system status as JSON.

    ``models_loaded`` is true only when the autoencoder, k-means and
    isolation-forest slots in ``models`` are all non-``None``.
    """
    required = ('autoencoder', 'kmeans', 'isolation_forest')
    return jsonify({
        'status': 'active',
        # Explicit None checks; model objects need not have a meaningful
        # (or even a defined) truth value.
        'models_loaded': all(models[name] is not None for name in required),
        'last_updated': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    })

if __name__ == '__main__':
    # Create static/css if it doesn't exist. exist_ok=True avoids a
    # check-then-create race.
    os.makedirs('static/css', exist_ok=True)

    # The server binds to all interfaces, so the interactive Werkzeug debugger
    # must not be on by default. Opt in for local development with
    # FLASK_DEBUG=1 (or "true"/"yes").
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    # Keep the reloader off. It re-executes this script in a child process,
    # which would also re-run the module-level start-up of the capture and
    # stats threads.
    app.run(host='0.0.0.0', port=5000, debug=debug, use_reloader=False)