In [2]:
import os
import sys
import numpy as np
import pandas as pd
from tqdm import tqdm
from scapy.all import rdpcap, Dot11
from multiprocessing import Pool, cpu_count
import gc

# ==========================================================
# CONFIGURATION
# ==========================================================
# Directory that is walked recursively for ``.cap`` capture files.
CAPTURE_ROOT = "capturedfiles"
# Destination of the per-burst feature table. The file name is a runtime value
# (spelling kept as-is so existing consumers of the CSV keep working).
OUTPUT_CSV = "adaptive_thresold.csv"

# Device label -> MAC address (lower-case, colon-separated).
# The labels are written to the "device_name" column of the output and are also
# the keys of the per-device threshold dictionary, so they must stay stable.
DEVICE_MACS = {
    "plug": "c0:f8:53:de:cf:2a",
    "plug1": "c0:f8:53:df:18:ea",
    "wall_socket": "d8:d6:68:06:6d:65",
    "tabel_lamp": "3c:0b:59:8f:25:42",
    "switch": "38:2c:e5:1d:02:fb",
    "switch1": "38:2c:e5:1c:cf:6e",
    "air_purifier": "50:ec:50:94:7b:a3",
    "motion_sensor": "f8:17:2d:b6:38:de",
    "motion_sensor1": "f8:17:2d:b4:3d:5a",
    "door_sensor": "18:de:50:54:8e:e9",
    "door_sensor1": "18:de:50:50:39:37",
    "baby_cam": "78:8b:2a:9c:80:1e",
    "camera": "5c:4e:ee:ce:f8:3b",
    "power_strip": "fc:3c:d7:53:f6:79",
    "non_iot_device": "4e:c2:19:38:75:dd",
    "non_iot_device1": "0c:e4:a0:e5:05:03",
    "non_iot_device2": "d8:f3:bc:67:1c:a5",
    "non_iot_device3": "b4:0e:de:38:64:95",
    "non_iot_device4": "66:35:9b:0c:1c:c6", 
}

# MAC of the access point; only frames exchanged between this address and one
# of the devices above are kept by the extraction step.
BSSID = "14:eb:b6:be:d7:1e"

# ==========================================================
# OPTIMIZED FILE PROCESSING
# ==========================================================
def find_all_cap_files(root_dir):
    """Recursively gather the path of every ``.cap`` file below *root_dir*.

    Each path is the visited directory joined with the file name, listed in
    the order ``os.walk`` yields them. A missing directory yields ``[]``.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(root_dir)
        for name in names
        if name.endswith(".cap")
    ]

def extract_all_devices_packets(packets):
    """Group the AP<->device 802.11 frames of a capture by device MAC in one pass.

    A frame is kept when it carries a Dot11 layer, is not a data frame of
    subtype 4 or 12, has both ``addr1`` and ``addr2`` set, and travels either
    from a known device to ``BSSID`` or from ``BSSID`` to a known device.

    Args:
        packets: Iterable of scapy packets (e.g. the result of ``rdpcap``).

    Returns:
        dict mapping every MAC in ``DEVICE_MACS`` to a list of
        ``(timestamp, frame_length, packet)`` tuples sorted by timestamp.
        Devices without matching traffic map to an empty list.
    """
    # The keys of this dict double as the set of known MACs, which gives an
    # O(1) membership test per packet instead of scanning DEVICE_MACS.values().
    device_packets = {mac: [] for mac in DEVICE_MACS.values()}

    for pkt in packets:
        if not pkt.haslayer(Dot11):
            continue

        dot11 = pkt[Dot11]
        ptype = getattr(dot11, "type", None)
        psub = getattr(dot11, "subtype", None)

        # Skip noisy frames
        if ptype == 2 and psub in (4, 12):
            continue

        src = dot11.addr2
        dst = dot11.addr1
        if src is None or dst is None:
            continue

        if dst == BSSID and src in device_packets:
            device_mac = src  # uplink: device -> access point
        elif src == BSSID and dst in device_packets:
            device_mac = dst  # downlink: access point -> device
        else:
            continue

        device_packets[device_mac].append((float(pkt.time), len(pkt), pkt))

    # Sort packets by time for each device
    for records in device_packets.values():
        records.sort(key=lambda rec: rec[0])

    return device_packets

def collect_device_iat_stats_optimized(root_dir, sample_size=50, max_packets=10000, seed=None):
    """Estimate per-device inter-arrival-time (IAT) statistics from sampled captures.

    Args:
        root_dir: Directory searched recursively for ``.cap`` files.
        sample_size: Maximum number of capture files to read. When more files
            exist, this many are drawn at random without replacement.
        max_packets: Value passed as ``count`` to ``rdpcap``, i.e. how many
            packets are read from the start of each sampled file.
        seed: Optional seed for the file sampling. ``None`` (the default) uses
            NumPy's global random state, as before.

    Returns:
        dict mapping each device name in ``DEVICE_MACS`` to
        ``{"median": ..., "mean": ..., "std": ...}`` (seconds). Devices with no
        observed IATs get the fallback ``{"median": 0.5, "mean": 0.5, "std": 0.1}``.
    """
    cap_files = find_all_cap_files(root_dir)
    device_iats = {dev: [] for dev in DEVICE_MACS}

    # Use only a sample of files for IAT calculation
    if len(cap_files) > sample_size:
        rng = np.random if seed is None else np.random.RandomState(seed)
        # choice() returns NumPy string scalars; hand plain str paths onwards.
        sample_files = [str(fp) for fp in rng.choice(cap_files, sample_size, replace=False)]
    else:
        sample_files = cap_files

    print(f"Collecting IAT stats from {len(sample_files)} sample files...")

    for cap_fp in tqdm(sample_files, desc="Sampling IATs"):
        try:
            # Only the first `max_packets` packets are read to bound time and memory.
            packets = rdpcap(cap_fp, count=max_packets)
        except Exception as e:
            print(f"Failed to read {cap_fp}: {e}", file=sys.stderr)
            continue

        device_packets = extract_all_devices_packets(packets)

        for dev_name, mac in DEVICE_MACS.items():
            time_pkt = device_packets.get(mac, [])
            if len(time_pkt) < 2:
                continue  # need at least two packets for one IAT
            times = np.array([t for t, _, _ in time_pkt])
            device_iats[dev_name].extend(np.diff(times).tolist())

        # Release the parsed capture before reading the next one.
        del packets, device_packets
        gc.collect()

    device_stats = {}
    for dev, iats in device_iats.items():
        if not iats:
            # Device never seen in the sample: fall back to fixed defaults.
            device_stats[dev] = {"median": 0.5, "mean": 0.5, "std": 0.1}
            continue
        iat_arr = np.asarray(iats)
        device_stats[dev] = {
            "median": float(np.median(iat_arr)),
            "mean": float(np.mean(iat_arr)),
            "std": float(np.std(iat_arr)),
        }
    return device_stats

# ==========================================================
# FIXED FEATURE ENGINEERING FUNCTION
# ==========================================================
def _entropy(values):
    """Return the Shannon entropy (bits) of the empirical distribution of *values*."""
    if len(values) == 0:
        return 0.0
    _, counts = np.unique(values, return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs + 1e-12)))


def _safe_div(a, b):
    """Divide *a* by *b* with a tiny epsilon added to *b* so a zero denominator does not raise."""
    return a / (b + 1e-12)


def _skewness(x):
    """Return the population skewness of *x* (0.0 for fewer than two samples)."""
    x = np.asarray(x)
    if x.size < 2:
        return 0.0
    spread = x.std(ddof=0) + 1e-12
    return float(np.mean((x - x.mean()) ** 3) / spread ** 3)


def _kurtosis_excess(x):
    """Return the population excess kurtosis of *x* (0.0 for fewer than two samples)."""
    x = np.asarray(x)
    if x.size < 2:
        return 0.0
    spread = x.std(ddof=0) + 1e-12
    return float(np.mean((x - x.mean()) ** 4) / spread ** 4 - 3.0)


def _hurst_rs(ts):
    """Estimate a Hurst exponent of *ts* by rescaled-range analysis, clamped to [0, 1].

    R/S is evaluated on prefixes ``ts[:L]`` for up to ten log-spaced lengths
    and the exponent is the slope of log(R/S) against log(L). Returns 0.5 when
    there are fewer than 10 samples or fewer than two usable window lengths.
    """
    ts = np.asarray(ts, dtype=float)
    n = ts.size
    if n < 10:
        return 0.5
    lags = np.unique(np.floor(np.logspace(1, np.log10(n / 2), num=10)).astype(int))
    window_lengths, rs_values = [], []
    for lag in lags:
        segment = ts[:lag]
        if segment.size < 2:
            continue
        cum = np.cumsum(segment - segment.mean())
        r = cum.max() - cum.min()
        if r <= 0:
            continue  # constant window: log(R/S) would be -inf
        window_lengths.append(segment.size)
        rs_values.append(r / (segment.std(ddof=0) + 1e-12))
    if len(rs_values) < 2:
        return 0.5
    try:
        # Regress against the actual window lengths, not their ordinal index.
        slope = np.polyfit(np.log(window_lengths), np.log(rs_values), 1)[0]
    except (ValueError, FloatingPointError, np.linalg.LinAlgError):
        return 0.5
    if not np.isfinite(slope):
        return 0.5
    return float(max(0.0, min(1.0, slope)))


def _peak_packet_density(times, window=1.0):
    """Return the largest packet count in any *window*-second span that starts at a packet.

    Two-pointer sweep; relies on *times* being in ascending order.
    """
    n = len(times)
    end = 0
    best = 0
    for start in range(n):
        limit = times[start] + window
        while end < n and times[end] <= limit:
            end += 1
        best = max(best, end - start)
    return best


def featurize_burst(burst, mac, filename, burst_id):
    """Compute the feature row for one burst of 802.11 frames.

    Args:
        burst: Non-empty list of ``(timestamp, frame_length, packet)`` tuples
            in ascending time order; each packet must carry a Dot11 layer.
        mac: MAC address of the device the burst belongs to. Frames whose
            ``addr2`` equals it count as uplink, frames whose ``addr1`` equals
            it count as downlink.
        filename: Path of the source capture; only its base name is stored.
        burst_id: Index of the burst within the capture/device.

    Returns:
        dict of features (insertion order defines the CSV column order), or
        ``None`` when the burst is empty or any step of the computation fails.
        Failures are reported on stderr instead of being raised so one bad
        burst cannot abort a worker.
    """
    if len(burst) == 0:
        return None

    try:
        times = np.array([t for t, _, _ in burst])
        sizes = np.array([s for _, s, _ in burst])

        # A single-packet burst has no gaps; use one zero IAT so the
        # statistics below stay defined.
        inter_arrival = np.diff(times) if len(times) > 1 else np.array([0.0])

        # ---- basic size / timing statistics ----
        total_pkts = len(burst)
        total_bytes = np.sum(sizes)
        flow_duration = times[-1] - times[0] if len(times) > 1 else 0.0
        mean_pkt_size = float(np.mean(sizes))
        std_pkt_size = float(np.std(sizes))
        min_pkt_size = float(np.min(sizes))
        max_pkt_size = float(np.max(sizes))
        mean_iat = float(np.mean(inter_arrival))
        std_iat = float(np.std(inter_arrival))
        min_iat = float(np.min(inter_arrival))
        max_iat = float(np.max(inter_arrival))
        median_iat = float(np.median(inter_arrival))
        throughput = float(total_bytes / (flow_duration + 1e-6))

        # ---- per-frame classification and counting ----
        fwd_sizes, bwd_sizes = [], []
        control_count = management_count = data_count = 0
        null_count = probe_req_count = probe_resp_count = 0
        ack_count = action_packet_count = block_ack_packet_count = 0
        max_consec_qos_data = 0
        current_qos_streak = 0
        seq_numbers = []
        frag_flag_count = 0
        qos_count = 0
        retry_flags = 0
        to_ds = from_ds = 0
        peers = set()

        for _, size, pkt in burst:
            dot11 = pkt[Dot11]
            src, dst = dot11.addr2, dot11.addr1
            peers.add(dst)
            peers.add(src)

            if src == mac:
                fwd_sizes.append(size)
            elif dst == mac:
                bwd_sizes.append(size)

            pkt_type = getattr(dot11, "type", None)
            pkt_subtype = getattr(dot11, "subtype", None)

            # scapy exposes FCfield as a flag object that is not an `int`
            # instance, so convert explicitly before testing bits; an
            # isinstance(..., int) guard would silently skip every frame.
            try:
                fc_bits = int(getattr(dot11, "FCfield", 0))
            except (TypeError, ValueError):
                fc_bits = 0
            if fc_bits & 0x01:
                to_ds += 1
            if fc_bits & 0x02:
                from_ds += 1
            if fc_bits & 0x08:
                retry_flags += 1

            # Sequence Control: upper 12 bits = sequence number, lower 4 bits =
            # fragment number. It may be absent (None) on some frames.
            try:
                seq_ctrl = int(dot11.SC)
            except (AttributeError, TypeError, ValueError):
                seq_ctrl = None
            if seq_ctrl is not None:
                seq_numbers.append((seq_ctrl >> 4) & 0xFFF)
                # The Dot11 layer has no `frag` attribute of its own, so the
                # fragment number is taken from SC.
                if seq_ctrl & 0x0F:
                    frag_flag_count += 1

            is_qos_data = pkt_type == 2 and pkt_subtype in range(8, 16)
            if is_qos_data:
                qos_count += 1

            if pkt_type == 0:  # Management
                management_count += 1
                if pkt_subtype == 4:
                    probe_req_count += 1
                elif pkt_subtype == 5:
                    probe_resp_count += 1
                elif pkt_subtype == 13:
                    action_packet_count += 1
            elif pkt_type == 1:  # Control
                control_count += 1
                if pkt_subtype == 13:
                    ack_count += 1
                elif pkt_subtype == 9:
                    block_ack_packet_count += 1
            elif pkt_type == 2:  # Data
                data_count += 1
                if pkt_subtype in (0x4, 0xC):
                    null_count += 1

            # Longest run of consecutive QoS data frames; any other
            # management/control/data frame breaks the run.
            if is_qos_data:
                current_qos_streak += 1
                max_consec_qos_data = max(max_consec_qos_data, current_qos_streak)
            elif pkt_type in (0, 1, 2):
                current_qos_streak = 0

        # ---- derived ratios ----
        uplink_packet_ratio = len(fwd_sizes) / (total_pkts + 1e-6)
        downlink_packet_ratio = len(bwd_sizes) / (total_pkts + 1e-6)
        # NOTE(review): despite the name this divides *all* control frames
        # (not only ACKs) by the data-frame count; kept to preserve the
        # existing column semantics.
        ack_to_data_ratio = _safe_div(control_count, data_count)
        entropy_pkt_size = _entropy(sizes)
        burstiness_index = (
            _safe_div(max_iat - mean_iat, mean_iat) if inter_arrival.size > 1 else 0.0
        )
        unique_peer_count = len(peers)
        payload_ratio = _safe_div(total_bytes, total_pkts * max_pkt_size)
        seq_entropy = _entropy(seq_numbers) if seq_numbers else 0.0
        tods_ratio = _safe_div(to_ds, total_pkts)
        fromds_ratio = _safe_div(from_ds, total_pkts)

        # ---- burst shape ----
        active_windows = np.sum(inter_arrival < (median_iat + 1e-12))
        active_time_ratio = _safe_div(active_windows, len(inter_arrival))

        if total_pkts > 1:
            mid_idx = total_pkts // 2
            first_half_sizes, second_half_sizes = sizes[:mid_idx], sizes[mid_idx:]
        else:
            first_half_sizes = second_half_sizes = sizes
        first_sum = float(np.sum(first_half_sizes)) if first_half_sizes.size else 0.0
        second_sum = float(np.sum(second_half_sizes)) if second_half_sizes.size else 0.0
        burst_symmetry = _safe_div(abs(first_sum - second_sum), first_sum + second_sum)

        # Peak number of packets inside any 1-second window.
        if flow_duration <= 0:
            peak_packet_density = total_pkts
        else:
            peak_packet_density = _peak_packet_density(times, window=1.0)

        # Least-squares slope of cumulative packet count against elapsed time.
        if flow_duration > 0 and total_pkts > 1:
            t_rel = times - times[0]
            cum_counts = np.arange(1, total_pkts + 1)
            denom = np.var(t_rel) + 1e-12
            slope = float(np.cov(t_rel, cum_counts, bias=True)[0, 1] / denom)
        else:
            slope = _safe_div(total_pkts - 1, flow_duration)

        # ---- distribution statistics ----
        iat_cv = _safe_div(std_iat, mean_iat)
        iat_skew = _skewness(inter_arrival)
        iat_kurt = _kurtosis_excess(inter_arrival)
        iat_entropy = _entropy(np.round(inter_arrival, 6))
        size_skew = _skewness(sizes)
        size_kurt = _kurtosis_excess(sizes)
        size_entropy = _entropy(np.round(sizes, 0))
        hurst = _hurst_rs(inter_arrival)

        # Normalised 5-bin histogram of the IATs.
        hist_counts, _ = np.histogram(inter_arrival, bins=5)
        iat_hist = (hist_counts / (hist_counts.sum() + 1e-12)).tolist()

        # ---- flow-style features ----
        init_bwd_win_bytes = float(sum(bwd_sizes[:5])) if bwd_sizes else 0.0
        bwd_pkt_len_max = float(np.max(bwd_sizes)) if bwd_sizes else 0.0
        fwd_pkt_len_max = float(np.max(fwd_sizes)) if fwd_sizes else 0.0
        bwd_pkts_per_sec = _safe_div(len(bwd_sizes), flow_duration)

        # Key order below is the column order of the output CSV.
        return {
            "file": os.path.basename(filename),
            "device_mac": mac,
            "burst_id": burst_id,
            "packet_count": total_pkts,
            "duration": flow_duration,
            "mean_pkt_size": mean_pkt_size,
            "std_pkt_size": std_pkt_size,
            "min_pkt_size": min_pkt_size,
            "max_pkt_size": max_pkt_size,
            "mean_iat": mean_iat,
            "std_iat": std_iat,
            "min_iat": min_iat,
            "max_iat": max_iat,
            "throughput": throughput,
            "burst_start_time": times[0],
            "burst_end_time": times[-1],
            "uplink_packet_ratio": float(uplink_packet_ratio),
            "downlink_packet_ratio": float(downlink_packet_ratio),
            "retry_count": int(retry_flags),
            "ack_to_data_ratio": float(ack_to_data_ratio),
            "entropy_pkt_size": float(entropy_pkt_size),
            "burstiness_index": float(burstiness_index),
            "unique_peer_count": int(unique_peer_count),
            "payload_ratio": float(payload_ratio),
            "control_packet_count": int(control_count),
            "management_packet_count": int(management_count),
            "data_packet_count": int(data_count),
            "null_packet_count": int(null_count),
            "probe_request_count": int(probe_req_count),
            "probe_response_count": int(probe_resp_count),
            "ack_packet_count": int(ack_count),
            "action_packet_count": int(action_packet_count),
            "block_ack_packet_count": int(block_ack_packet_count),
            "max_consec_qos_data_count": int(max_consec_qos_data),
            "seq_entropy": float(seq_entropy),
            "frag_flag_count": int(frag_flag_count),
            "qos_packet_count": int(qos_count),
            "tods_ratio": float(tods_ratio),
            "fromds_ratio": float(fromds_ratio),
            "retry_flags_count": int(retry_flags),
            "active_time_ratio": float(active_time_ratio),
            "burst_symmetry": float(burst_symmetry),
            "peak_packet_density_1s": int(peak_packet_density),
            "packet_arrival_slope": float(slope),
            "iat_cv": float(iat_cv),
            "iat_skew": float(iat_skew),
            "iat_kurtosis": float(iat_kurt),
            "iat_entropy": float(iat_entropy),
            "size_skew": float(size_skew),
            "size_kurtosis": float(size_kurt),
            "size_entropy": float(size_entropy),
            "hurst_exponent": float(hurst),
            "iat_hist_bin0": float(iat_hist[0]),
            "iat_hist_bin1": float(iat_hist[1]),
            "iat_hist_bin2": float(iat_hist[2]),
            "iat_hist_bin3": float(iat_hist[3]),
            "iat_hist_bin4": float(iat_hist[4]),
            "InitBwdWinByts": init_bwd_win_bytes,
            "FlowIATMin": min_iat,
            "FlowIATMax": max_iat,
            "BwdPktLenMax": bwd_pkt_len_max,
            "FlowDuration": flow_duration,
            "FlowBytesPerSec": throughput,
            "BwdPktsPerSec": bwd_pkts_per_sec,
            "FwdPktLenMax": fwd_pkt_len_max,
            "PktLenMin": min_pkt_size,
            "PktLenMax": max_pkt_size,
        }

    except Exception as e:
        # Best-effort boundary: report on stderr (like the capture readers do)
        # and let the caller skip this burst.
        print(f"Error in featurize_burst for {filename} burst {burst_id}: {e}", file=sys.stderr)
        return None

def process_single_cap_optimized(args):
    """Turn one capture file into a list of per-burst feature dicts.

    *args* is a ``(capture_path, thresholds)`` pair so the function can be
    mapped over a worker pool. ``thresholds`` maps device names to the burst
    gap in seconds (0.5 when a device is missing). An unreadable capture is
    reported on stderr and produces an empty list.
    """
    cap_fp, thresholds = args
    try:
        # Use count to limit very large files
        packets = rdpcap(cap_fp, count=10000)
    except Exception as e:
        print(f"Failed to read {cap_fp}: {e}", file=sys.stderr)
        return []

    per_device = extract_all_devices_packets(packets)
    rows = []

    for dev_name, mac in DEVICE_MACS.items():
        records = per_device.get(mac, [])
        if not records:
            continue

        gap_threshold = thresholds.get(dev_name, 0.5)
        device_bursts = segment_into_bursts(records, gap_threshold=gap_threshold)

        for burst_idx, burst in enumerate(device_bursts):
            # Bursts with fewer than three packets are ignored.
            if len(burst) < 3:
                continue

            row = featurize_burst(burst, mac, cap_fp, burst_idx)
            if row is None:
                # Featurization failed for this burst; leave it out.
                continue
            row["device_name"] = dev_name
            row["adaptive_gap"] = gap_threshold
            rows.append(row)

    # Drop the parsed capture before returning to the pool.
    del packets
    gc.collect()

    return rows

def process_all_captures_optimized(root_dir, thresholds, max_workers=None):
    """Featurize every capture under *root_dir* in parallel.

    Args:
        root_dir: Directory searched recursively for ``.cap`` files.
        thresholds: dict mapping device name to its burst gap threshold (s).
        max_workers: Number of worker processes. Defaults to
            ``min(cpu_count(), 8)``; values below 1 are raised to 1.

    Returns:
        pandas.DataFrame with one row per burst (empty when nothing was found
        or no burst could be featurized).
    """
    cap_files = find_all_cap_files(root_dir)

    if max_workers is None:
        max_workers = min(cpu_count(), 8)  # Limit workers to avoid memory issues
    max_workers = max(1, int(max_workers))

    print(f"Processing {len(cap_files)} files with {max_workers} workers...")

    if not cap_files:
        # Nothing to do; avoid spawning worker processes.
        return pd.DataFrame()

    # Files are still handled chunk by chunk so results are collected and
    # progress is shown in bounded batches.
    chunk_size = min(100, len(cap_files) // (max_workers * 2) + 1)
    n_chunks = (len(cap_files) - 1) // chunk_size + 1

    all_feats = []
    # One pool for the whole run instead of a fresh pool per chunk; workers are
    # recycled after `chunk_size` files each so their memory stays bounded.
    with Pool(max_workers, maxtasksperchild=chunk_size) as pool:
        for chunk_no, start in enumerate(range(0, len(cap_files), chunk_size), start=1):
            chunk_args = [(fp, thresholds) for fp in cap_files[start:start + chunk_size]]
            results = list(tqdm(
                pool.imap(process_single_cap_optimized, chunk_args),
                total=len(chunk_args),
                desc=f"Processing chunk {chunk_no}/{n_chunks}"
            ))

            for file_feats in results:
                all_feats.extend(file_feats)

            # Clear memory between chunks
            del results
            gc.collect()

    return pd.DataFrame(all_feats)

def segment_into_bursts(time_pkt, gap_threshold=0.1):
    """Split time-ordered records into bursts at gaps larger than *gap_threshold*.

    Each record's first element is its timestamp. A new burst starts whenever
    the gap to the previous record strictly exceeds the threshold. Returns a
    list of bursts (lists of records); empty input gives ``[]``.
    """
    if not time_pkt:
        return []

    bursts = [[time_pkt[0]]]
    for previous, current in zip(time_pkt, time_pkt[1:]):
        if current[0] - previous[0] > gap_threshold:
            bursts.append([current])
        else:
            bursts[-1].append(current)
    return bursts

def compute_adaptive_thresholds(device_stats, factor_dict=None):
    """Derive a burst gap threshold per device from its median IAT.

    The threshold is ``median * factor`` with a floor of 0.001. The factor is
    2.0 unless *factor_dict* supplies a value for that device.
    """
    overrides = factor_dict if factor_dict else {}
    return {
        dev: max(0.001, stats["median"] * overrides.get(dev, 2.0))
        for dev, stats in device_stats.items()
    }

# ==========================================================
# MAIN EXECUTION
# ==========================================================
if __name__ == "__main__":
    # Step 1: estimate per-device inter-arrival-time statistics from a sample
    # of the capture files.
    print("Collecting IAT stats per device (optimized)...")
    device_stats = collect_device_iat_stats_optimized(CAPTURE_ROOT)
    
    # Step 2: turn each device's median IAT into its burst gap threshold.
    print("Computing adaptive thresholds per device...")
    adaptive_thresholds = compute_adaptive_thresholds(device_stats)
    
    # Step 3: segment every capture into bursts with those thresholds and
    # featurize each burst.
    print("Processing captures with adaptive thresholds...")
    df = process_all_captures_optimized(CAPTURE_ROOT, adaptive_thresholds)
    
    # Only write the CSV when at least one burst was featurized.
    if len(df) > 0:
        df.to_csv(OUTPUT_CSV, index=False)
        print(f"Saved {len(df)} bursts to {OUTPUT_CSV}")
    else:
        print("No features extracted - output CSV not created")
    
    # Report the thresholds that were applied. Devices that were absent from
    # the IAT sample use the fallback median of 0.5, i.e. a 1.0000 sec gap.
    print("Adaptive thresholds used per device:")
    for dev, t in adaptive_thresholds.items():
        print(f" - {dev}: {t:.4f} sec")

Collecting IAT stats per device (optimized)...
Collecting IAT stats from 50 sample files...


Sampling IATs: 100%|██████████| 50/50 [00:13<00:00,  3.77it/s]

Computing adaptive thresholds per device...
Processing captures with adaptive thresholds...
Processing 1760 files with 8 workers...



Processing chunk 1/18: 100%|██████████| 100/100 [00:04<00:00, 22.43it/s]
Processing chunk 2/18: 100%|██████████| 100/100 [00:03<00:00, 32.26it/s]
Processing chunk 3/18: 100%|██████████| 100/100 [00:03<00:00, 28.53it/s]
Processing chunk 4/18: 100%|██████████| 100/100 [00:03<00:00, 25.05it/s]
Processing chunk 5/18: 100%|██████████| 100/100 [00:06<00:00, 16.19it/s]
Processing chunk 6/18: 100%|██████████| 100/100 [00:06<00:00, 15.52it/s]
Processing chunk 7/18: 100%|██████████| 100/100 [00:07<00:00, 13.60it/s]
Processing chunk 8/18: 100%|██████████| 100/100 [00:03<00:00, 27.13it/s]
Processing chunk 9/18: 100%|██████████| 100/100 [00:03<00:00, 30.58it/s]
Processing chunk 10/18: 100%|██████████| 100/100 [00:06<00:00, 14.37it/s]
Processing chunk 11/18: 100%|██████████| 100/100 [00:07<00:00, 14.06it/s]
Processing chunk 12/18: 100%|██████████| 100/100 [00:03<00:00, 32.54it/s]
Processing chunk 13/18: 100%|██████████| 100/100 [00:01<00:00, 50.94it/s]
Processing chunk 14/18: 100%|██████████| 100/1

Saved 82013 bursts to adaptive_thresold.csv
Adaptive thresholds used per device:
 - plug: 0.0039 sec
 - plug1: 1.0000 sec
 - wall_socket: 0.0043 sec
 - tabel_lamp: 0.0010 sec
 - switch: 0.0046 sec
 - switch1: 0.0045 sec
 - air_purifier: 0.0123 sec
 - motion_sensor: 0.0030 sec
 - motion_sensor1: 1.0000 sec
 - door_sensor: 0.0039 sec
 - door_sensor1: 1.0000 sec
 - baby_cam: 0.0010 sec
 - camera: 0.0131 sec
 - power_strip: 0.0056 sec
 - non_iot_device: 0.0010 sec
 - non_iot_device1: 1.0000 sec
 - non_iot_device2: 1.0000 sec
 - non_iot_device3: 0.0028 sec
 - non_iot_device4: 1.0000 sec
