#!/usr/bin/env python3
"""
Educational Network Packet Sniffer
==================================

IMPORTANT ETHICAL NOTICE:
This tool is for EDUCATIONAL PURPOSES ONLY and should only be used on:
- Your own network
- Networks you own or have explicit permission to monitor
- Lab environments for learning cybersecurity

Unauthorized network monitoring is illegal and unethical.
Always comply with local laws and obtain proper permissions.
"""

import socket
import struct
import textwrap
import argparse
import sys
import time
from datetime import datetime


class PacketSniffer:
    """Educational packet sniffer for network analysis learning.

    Captures raw traffic, decodes the Ethernet (Linux only), IPv4, TCP, UDP
    and ICMP headers, and prints a human-readable summary of each packet.

    Args:
        interface: Optional capture interface. On Linux this is an interface
            name (e.g. ``"eth0"``) that the ``AF_PACKET`` socket is bound to.
            On Windows, where raw sockets are bound to an address, it is the
            local IPv4 address to bind to. ``None`` keeps the defaults: all
            interfaces on Linux, and the address that the host name resolves
            to on Windows.
        filter_protocol: Optional protocol name, compared case-insensitively
            with the values of ``self.protocols``. ``None`` captures every
            protocol.
    """

    # EtherType of IPv4, as a host-order integer after '!H' decoding.
    ETH_P_IP = 0x0800
    # ETH_P_ALL: ask the kernel for frames of every protocol (AF_PACKET).
    ETH_P_ALL = 0x0003

    # Human-readable names for the ICMP types this tool knows about.
    ICMP_TYPES = {
        0: 'Echo Reply',
        3: 'Destination Unreachable',
        8: 'Echo Request',
        11: 'Time Exceeded',
        12: 'Parameter Problem'
    }

    def __init__(self, interface=None, filter_protocol=None):
        self.interface = interface
        self.filter_protocol = filter_protocol
        self.packet_count = 0
        # Reset when a capture actually starts (see sniff_packets).
        self.start_time = time.time()

        # IP protocol numbers -> names.
        self.protocols = {
            1: 'ICMP',
            6: 'TCP',
            17: 'UDP',
            2: 'IGMP',
            89: 'OSPF'
        }

    @staticmethod
    def _is_windows():
        """Return True when running on Windows (IP-level raw sockets)."""
        return sys.platform.startswith('win')

    def create_socket(self):
        """Create and configure the raw socket used for packet capture.

        Returns:
            socket.socket: A raw socket ready for ``recvfrom``.

        Terminates the process with ``sys.exit(1)`` when the socket cannot
        be created or bound, for example without administrator/root
        privileges or when the requested interface does not exist.
        """
        try:
            if self._is_windows():
                # Windows raw sockets deliver IP datagrams (no link layer).
                # They must be bound to a local address before SIO_RCVALL.
                host = self.interface or socket.gethostbyname(socket.gethostname())
                sock = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IP)
                sock.bind((host, 0))
                sock.setsockopt(socket.IPPROTO_IP, socket.IP_HDRINCL, 1)
                # Enable promiscuous mode on Windows.
                sock.ioctl(socket.SIO_RCVALL, socket.RCVALL_ON)
            else:
                # Linux: link-layer capture of every protocol.
                sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW,
                                     socket.htons(self.ETH_P_ALL))
                if self.interface:
                    # Restrict capture to the requested interface.
                    sock.bind((self.interface, 0))

            return sock
        except PermissionError:
            print("Error: This tool requires administrator/root privileges to capture packets.")
            print("On Linux: sudo python3 packet_sniffer.py")
            print("On Windows: Run as Administrator")
            sys.exit(1)
        except Exception as e:
            print(f"Error creating socket: {e}")
            sys.exit(1)

    def _close_socket(self, sock):
        """Turn promiscuous mode off again (Windows) and close *sock*.

        Disabling RCVALL is best effort: any OSError is ignored because the
        socket is closed immediately afterwards anyway.
        """
        try:
            if self._is_windows():
                # Mirror the RCVALL_ON issued in create_socket.
                sock.ioctl(socket.SIO_RCVALL, socket.RCVALL_OFF)
        except OSError:
            pass
        finally:
            sock.close()

    def parse_ethernet_header(self, data):
        """Parse a 14-byte Ethernet II header (Linux captures only).

        Returns a dict with ``dest_mac``, ``src_mac``, ``type`` (EtherType as
        a plain integer, e.g. 0x0800 for IPv4) and ``data`` (the frame
        payload). Raises ``struct.error`` if *data* is shorter than 14 bytes.
        """
        eth_header = struct.unpack('!6s6sH', data[:14])
        dest_mac = ':'.join(f'{b:02x}' for b in eth_header[0])
        src_mac = ':'.join(f'{b:02x}' for b in eth_header[1])
        # '!H' already converts network to host order; applying ntohs again
        # would swap the bytes back on little-endian hosts (0x0800 -> 0x0008).
        eth_type = eth_header[2]

        return {
            'dest_mac': dest_mac,
            'src_mac': src_mac,
            'type': eth_type,
            'data': data[14:]
        }

    def parse_ip_header(self, data):
        """Parse an IPv4 header from packet data.

        Returns a dict with ``version``, ``header_length`` (bytes),
        ``total_length``, ``ttl``, ``protocol``, ``src_ip``, ``dest_ip`` and
        ``data`` (the IP payload).

        Raises:
            struct.error: *data* is shorter than the 20-byte fixed header.
            ValueError: the IHL field gives a header shorter than 20 bytes.
        """
        # Unpack the fixed 20-byte part of the IP header.
        ip_header = struct.unpack('!BBHHHBBH4s4s', data[:20])

        version_ihl = ip_header[0]
        version = version_ihl >> 4
        header_length = (version_ihl & 0xF) * 4
        if header_length < 20:
            raise ValueError(f"invalid IPv4 header length: {header_length}")

        total_length = ip_header[2]
        ttl = ip_header[5]
        protocol = ip_header[6]
        src_addr = socket.inet_ntoa(ip_header[8])
        dest_addr = socket.inet_ntoa(ip_header[9])

        # Short Ethernet frames are padded to the minimum frame size, so bytes
        # past the IP total length are link-layer padding, not payload. A
        # total length smaller than the header itself is unusable, so in that
        # case fall back to everything that was captured.
        end = total_length if total_length >= header_length else len(data)

        return {
            'version': version,
            'header_length': header_length,
            'total_length': total_length,
            'ttl': ttl,
            'protocol': protocol,
            'src_ip': src_addr,
            'dest_ip': dest_addr,
            'data': data[header_length:end]
        }

    def parse_tcp_header(self, data):
        """Parse a TCP header.

        Returns ports, sequence/ack numbers, the six classic flags, the
        header length in bytes and the segment payload.

        Raises:
            struct.error: *data* is shorter than 20 bytes.
            ValueError: the data-offset field is smaller than 5 words.
        """
        tcp_header = struct.unpack('!HHLLBBHHH', data[:20])

        src_port = tcp_header[0]
        dest_port = tcp_header[1]
        seq_num = tcp_header[2]
        ack_num = tcp_header[3]
        flags = tcp_header[5]

        header_length = (tcp_header[4] >> 4) * 4
        if header_length < 20:
            raise ValueError(f"invalid TCP header length: {header_length}")

        return {
            'src_port': src_port,
            'dest_port': dest_port,
            'seq': seq_num,
            'ack': ack_num,
            'flags': {
                'URG': (flags & 32) >> 5,
                'ACK': (flags & 16) >> 4,
                'PSH': (flags & 8) >> 3,
                'RST': (flags & 4) >> 2,
                'SYN': (flags & 2) >> 1,
                'FIN': flags & 1
            },
            'header_length': header_length,
            'data': data[header_length:]
        }

    def parse_udp_header(self, data):
        """Parse an 8-byte UDP header; raises ``struct.error`` if too short."""
        udp_header = struct.unpack('!HHHH', data[:8])

        return {
            'src_port': udp_header[0],
            'dest_port': udp_header[1],
            'length': udp_header[2],
            'checksum': udp_header[3],
            'data': data[8:]
        }

    def parse_icmp_header(self, data):
        """Parse the 4-byte ICMP type/code/checksum prefix.

        Raises ``struct.error`` if *data* is shorter than 4 bytes.
        """
        icmp_header = struct.unpack('!BBH', data[:4])

        return {
            'type': icmp_header[0],
            'type_name': self.ICMP_TYPES.get(icmp_header[0], 'Unknown'),
            'code': icmp_header[1],
            'checksum': icmp_header[2],
            'data': data[4:]
        }

    def format_payload(self, data, max_bytes=64):
        """Return a hex + ASCII dump of the first *max_bytes* of *data*."""
        if not data:
            return "No payload data"

        # Limit data for display.
        display_data = data[:max_bytes]

        hex_str = ' '.join(f'{b:02x}' for b in display_data)
        # Non-printable bytes are shown as '.'.
        ascii_str = ''.join(chr(b) if 32 <= b <= 126 else '.' for b in display_data)

        return f"Hex: {hex_str}\nASCII: {ascii_str}"

    def display_packet_info(self, packet_info):
        """Print the decoded layers of one packet (as built by _parse_packet)."""
        print(f"\n{'='*80}")
        print(f"Packet #{self.packet_count} - {datetime.now().strftime('%H:%M:%S.%f')[:-3]}")
        print(f"{'='*80}")

        # Ethernet info (if available).
        if 'ethernet' in packet_info:
            eth = packet_info['ethernet']
            print(f"Ethernet: {eth['src_mac']} → {eth['dest_mac']} (Type: 0x{eth['type']:04x})")

        # IP info.
        if 'ip' in packet_info:
            ip = packet_info['ip']
            protocol_name = self.protocols.get(ip['protocol'], f"Unknown({ip['protocol']})")
            print(f"IP: {ip['src_ip']} → {ip['dest_ip']} (Protocol: {protocol_name}, TTL: {ip['ttl']})")

        # Protocol-specific info.
        if 'tcp' in packet_info:
            tcp = packet_info['tcp']
            flags = [flag for flag, value in tcp['flags'].items() if value]
            flags_str = ','.join(flags) if flags else 'None'
            print(f"TCP: Port {tcp['src_port']} → {tcp['dest_port']} (Flags: {flags_str})")
            print(f"     Seq: {tcp['seq']}, Ack: {tcp['ack']}")

        elif 'udp' in packet_info:
            udp = packet_info['udp']
            print(f"UDP: Port {udp['src_port']} → {udp['dest_port']} (Length: {udp['length']})")

        elif 'icmp' in packet_info:
            icmp = packet_info['icmp']
            print(f"ICMP: Type {icmp['type']} ({icmp['type_name']}) Code {icmp['code']}")

        # Payload info.
        if 'payload' in packet_info and packet_info['payload']:
            print(f"\nPayload ({len(packet_info['payload'])} bytes):")
            print(self.format_payload(packet_info['payload']))

    def should_capture_packet(self, protocol):
        """Return True if IP protocol number *protocol* passes the filter."""
        if self.filter_protocol is None:
            return True

        protocol_name = self.protocols.get(protocol, '').lower()
        return protocol_name == self.filter_protocol.lower()

    def _parse_packet(self, raw_data):
        """Decode one captured packet into the dict used by display_packet_info.

        Returns ``None`` when the packet is not IPv4 (Linux) or is excluded by
        the protocol filter. Raises ``struct.error`` / ``ValueError`` for
        truncated or malformed headers.
        """
        packet_info = {}

        if self._is_windows():
            # Windows raw sockets start at the IP header.
            ip_data = raw_data
        else:
            # Linux AF_PACKET frames start with the Ethernet header.
            eth_info = self.parse_ethernet_header(raw_data)
            if eth_info['type'] != self.ETH_P_IP:
                return None
            packet_info['ethernet'] = eth_info
            ip_data = eth_info['data']

        ip_info = self.parse_ip_header(ip_data)
        packet_info['ip'] = ip_info

        if not self.should_capture_packet(ip_info['protocol']):
            return None

        # IP protocol number -> (result key, header parser).
        transport_parsers = {
            6: ('tcp', self.parse_tcp_header),
            17: ('udp', self.parse_udp_header),
            1: ('icmp', self.parse_icmp_header),
        }
        entry = transport_parsers.get(ip_info['protocol'])
        if entry is None:
            packet_info['payload'] = ip_info['data']
        else:
            key, parse = entry
            header = parse(ip_info['data'])
            packet_info[key] = header
            packet_info['payload'] = header['data']

        return packet_info

    def sniff_packets(self, count=0):
        """Capture and display packets until *count* are shown or Ctrl+C.

        Args:
            count: Number of matching packets to display; 0 or less means
                unlimited.

        Malformed or truncated packets are skipped rather than aborting the
        capture. The socket is always closed and a summary printed on exit.
        """
        print(f"\n{'='*80}")
        print("EDUCATIONAL PACKET SNIFFER STARTED")
        print(f"{'='*80}")
        print(f"Filter: {self.filter_protocol or 'All protocols'}")
        print(f"Capture limit: {count if count > 0 else 'Unlimited'}")
        print("Press Ctrl+C to stop\n")

        sock = self.create_socket()
        # Measure the capture itself, not the time since construction.
        self.start_time = time.time()

        try:
            while count <= 0 or self.packet_count < count:
                raw_data, _addr = sock.recvfrom(65535)

                try:
                    packet_info = self._parse_packet(raw_data)
                except (struct.error, ValueError):
                    # One bad packet should not end the whole capture.
                    continue

                if packet_info is None:
                    continue

                self.packet_count += 1
                self.display_packet_info(packet_info)

        except KeyboardInterrupt:
            print("\n\nCapture stopped by user.")
        except Exception as e:
            print(f"\nError during packet capture: {e}")
        finally:
            self._close_socket(sock)
            self.print_summary()

    def print_summary(self):
        """Print packet count, capture duration and average packet rate."""
        duration = time.time() - self.start_time
        # Guard against a zero (or clock-skewed negative) duration.
        rate = self.packet_count / duration if duration > 0 else 0.0
        print(f"\n{'='*80}")
        print("CAPTURE SUMMARY")
        print(f"{'='*80}")
        print(f"Packets captured: {self.packet_count}")
        print(f"Duration: {duration:.2f} seconds")
        print(f"Average rate: {rate:.2f} packets/second")
        print("\nRemember: Use this tool responsibly and only on authorized networks!")


def main():
    """Parse command-line options, confirm authorization, and run a capture.

    Exits with status 0 when the user does not confirm permission, including
    when stdin is closed or Ctrl+C is pressed at the prompt. Invalid options,
    such as a negative ``--count``, are rejected through ``argparse``.
    """
    parser = argparse.ArgumentParser(
        description="Educational Network Packet Sniffer",
        epilog="IMPORTANT: Only use on networks you own or have permission to monitor!"
    )

    parser.add_argument('-c', '--count', type=int, default=0,
                       help='Number of packets to capture (0 for unlimited)')
    parser.add_argument('-p', '--protocol', choices=['tcp', 'udp', 'icmp'],
                       help='Filter by protocol')
    parser.add_argument('-i', '--interface',
                       help='Network interface to use (optional)')

    args = parser.parse_args()
    if args.count < 0:
        parser.error("--count must be 0 (unlimited) or a positive integer")

    # Display ethical warning.
    print("ETHICAL USE WARNING:")
    print("This tool is for educational purposes only!")
    print("Only use on networks you own or have explicit permission to monitor.")
    print("Unauthorized network monitoring may be illegal.")

    try:
        response = input("\nDo you have permission to monitor this network? (yes/no): ")
    except (EOFError, KeyboardInterrupt):
        # No interactive answer: treat it as "no" instead of a traceback.
        print()
        response = ''
    # Tolerate surrounding whitespace such as "yes " or " YES".
    if response.strip().lower() != 'yes':
        print("Exiting. Only use this tool on authorized networks.")
        sys.exit(0)

    # Create and run sniffer.
    sniffer = PacketSniffer(
        interface=args.interface,
        filter_protocol=args.protocol
    )

    sniffer.sniff_packets(count=args.count)


if __name__ == "__main__":
    main()