In [2]:
#!/usr/bin/env python3
"""
generate_slm_dataset.py

Fixed CLI-parsing so it is safe to run inside Jupyter kernels (uses parse_known_args()).
Usage (terminal):
    python generate_slm_dataset.py --seed 123 --n-hosts 100 --days 15 --compromised 40 \
        --avg-events 120 --out /tmp/slm_events.csv --meta /tmp/slm_meta.json

Or in a notebook:
    # Option A: run as a script (will ignore kernel args)
    !python generate_slm_dataset.py --seed 123 --out /tmp/slm.csv --meta /tmp/slm_meta.json

    # Option B: import the file and call generate_dataset(...) directly from the notebook
"""
import argparse
import csv
import json
import math
import random
import uuid
import sys
from datetime import datetime, timedelta, timezone

# ----------------------------
# Templates & helpers
# ----------------------------
PROCESS_TEMPLATES = {
    "workstation": ["explorer.exe","chrome.exe","outlook.exe","teams.exe","powershell.exe","notepad.exe","svchost.exe","onedrive.exe"],
    "server": ["svchost.exe","sqlservr.exe","backup.exe","robocopy.exe","nginx.exe","powershell.exe","elasticd.exe"],
    "db": ["mysqld.exe","postgres.exe","sqlservr.exe","backup.exe","dbclient.exe","odbcsvc.exe"],
    "admin": ["powershell.exe","psexec.exe","wmic.exe","mimikatz.exe","cmd.exe","putty.exe","plink.exe"]
}
EVENT_TYPES = ["auth","exec","conn","file","registry","dns","http"]
PROTOCOLS = ["SMB","RDP","WMI","HTTP","DNS","TCP","ICMP"]

def iso(ts):
    return ts.replace(microsecond=0).astimezone(timezone.utc).isoformat().replace("+00:00","Z")

# ----------------------------
# Core generator (same logic as before)
# ----------------------------
def generate_dataset(seed=42,
                     n_hosts=100,
                     days=15,
                     start_date=datetime(2025,10,1,tzinfo=timezone.utc),
                     compromised_count=40,
                     avg_events_per_host_day=120,
                     out_csv="slm_synthetic_events.csv",
                     out_meta="slm_metadata.json"):
    random.seed(seed)

    host_ids = [f"Host_{i+1}" for i in range(n_hosts)]
    host_uuid = {h: str(uuid.uuid4()) for h in host_ids}
    roles = (["workstation"] * max(1, int(n_hosts * 0.7)) +
             ["server"] * max(1, int(n_hosts * 0.15)) +
             ["db"] * max(1, int(n_hosts * 0.08)) +
             ["admin"] * max(1, int(n_hosts * 0.07)))
    roles = roles[:n_hosts]
    random.shuffle(roles)
    host_role = {h: roles[i] for i,h in enumerate(host_ids)}

    host_profile = {}
    user_pool = [f"user_{i+1}" for i in range(200)]
    admin_pool = [f"admin_{i+1}" for i in range(20)]
    user_pool.extend(admin_pool)
    for i,h in enumerate(host_ids):
        role = host_role[h]
        typical = int(avg_events_per_host_day * (1.0 + random.uniform(-0.25,0.25)) * (1.25 if role in ["server","db"] else 1.0) * (1.1 if role=="admin" else 1.0))
        host_profile[h] = {
            "role": role,
            "hostname": f"host-{i+1}.corp.local",
            "host_id_unique": host_uuid[h],
            "typical_events_per_day": max(5, typical),
            "processes": list(set(PROCESS_TEMPLATES[role] + random.sample(sum(PROCESS_TEMPLATES.values(), []), k=3))),
            "users": random.sample(user_pool, k=6 if role=="workstation" else 10),
            "normal_hours": (9,18) if role!="admin" else (8,20),
            "index": i+1
        }

    compromised = sorted(random.sample(host_ids, compromised_count))
    meta = {"seed": seed, "compromised_hosts": {}, "generated_on": iso(datetime.now(timezone.utc))}
    for h in compromised:
        d0 = start_date
        recon_w = (d0 + timedelta(days=0), d0 + timedelta(days=2, hours=23, minutes=59))
        cred_w  = (d0 + timedelta(days=3), d0 + timedelta(days=5, hours=23, minutes=59))
        lat_w   = (d0 + timedelta(days=6), d0 + timedelta(days=9, hours=23, minutes=59))
        pers_w  = (d0 + timedelta(days=10), d0 + timedelta(days=days-1, hours=23, minutes=59))
        meta["compromised_hosts"][h] = {
            "recon_window": (iso(recon_w[0]), iso(recon_w[1])),
            "credential_access_window": (iso(cred_w[0]), iso(cred_w[1])),
            "lateral_movement_window": (iso(lat_w[0]), iso(lat_w[1])),
            "persistence_window": (iso(pers_w[0]), iso(pers_w[1])),
        }

    per_host_totals = {}
    for h in host_ids:
        base = host_profile[h]["typical_events_per_day"] * days
        per_host_totals[h] = max(5, int(random.gauss(base, base*0.12)))

    target_min, target_max = 150000, 250000
    current_total = sum(per_host_totals.values())
    target = max(target_min, min(target_max, int(n_hosts * avg_events_per_host_day * days)))
    scale = target / current_total if current_total>0 else 1.0
    for h in per_host_totals:
        per_host_totals[h] = max(1, int(per_host_totals[h] * scale))

    def gen_timestamps_for_host(h, total):
        start = start_date
        timestamps = []
        base_day = host_profile[h]["typical_events_per_day"]
        per_day = []
        for d in range(days):
            mult = 1.0 + 0.25 * math.sin((d/3.0) + random.random())
            if random.random() < 0.02 and host_profile[h]["role"] in ("server","db"):
                mult *= (1.0 + random.uniform(0.8,2.0))
            count = max(1, int(random.gauss(base_day * mult, base_day * 0.15)))
            per_day.append(count)
        s = sum(per_day) or 1
        per_day = [max(1, int(p * (total / s))) for p in per_day]
        for d, cnt in enumerate(per_day):
            day_start = start + timedelta(days=d)
            hstart, hend = host_profile[h]["normal_hours"]
            for _ in range(cnt):
                if random.random() < 0.8:
                    hour = random.randint(hstart, max(hstart, hend-1))
                else:
                    hour = random.choices(range(24), weights=[1 if (i>=0 and i<6) else (2 if (i>=9 and i<18) else 1) for i in range(24)])[0]
                minute = random.randint(0,59)
                second = random.randint(0,59)
                ts = datetime(day_start.year, day_start.month, day_start.day, hour, minute, second, tzinfo=timezone.utc)
                ts = ts + timedelta(seconds=random.randint(0,59))
                timestamps.append(ts)
        timestamps.sort()
        if len(timestamps) > total:
            timestamps = timestamps[:total]
        while len(timestamps) < total:
            if len(timestamps) < 2:
                timestamps.append(start + timedelta(days=random.randint(0,days-1), hours=random.randint(0,23),
                                                    minutes=random.randint(0,59), seconds=random.randint(0,59)))
            else:
                i = random.randint(0, len(timestamps)-2)
                a, b = timestamps[i], timestamps[i+1]
                mid = a + (b - a) / 2
                timestamps.insert(i+1, mid)
        for i in range(1, len(timestamps)):
            if timestamps[i] <= timestamps[i-1]:
                timestamps[i] = timestamps[i-1] + timedelta(seconds=1)
        return timestamps

    def pick_dst_internal(h):
        choices = [x for x in host_ids if x!=h]
        dst = random.choice(choices)
        idx = (int(dst.split("_")[1]) if "_" in dst else 1)
        ip = f"10.0.{idx//256}.{idx%256}"
        return ip, dst

    rows = []
    for h in host_ids:
        total_events = per_host_totals[h]
        timestamps = gen_timestamps_for_host(h, total_events)
        role = host_profile[h]["role"]
        processes = host_profile[h]["processes"]
        users = host_profile[h]["users"]
        src_ip = f"10.0.{host_profile[h]['index']//256}.{host_profile[h]['index']%256}"
        is_comp = h in compromised
        comp_win = meta["compromised_hosts"].get(h)

        session_id = f"{h}-sess-{uuid.uuid4().hex[:8]}"
        for ts in timestamps:
            r = random.random()
            if r < 0.18:
                et = "auth"
            elif r < 0.40:
                et = "conn"
            elif r < 0.60:
                et = "file"
            elif r < 0.70:
                et = "exec"
            elif r < 0.80:
                et = "registry"
            elif r < 0.92:
                et = "dns"
            else:
                et = "http"

            proc = random.choice(processes)
            parent = random.choice(processes) if random.random() < 0.3 else None
            user = random.choice(users)
            dst_ip = None
            dst_host = None
            auth_result = "none"
            auth_type = None
            logon_type = None
            protocol = None
            port = None
            bytes_in = 0
            bytes_out = 0
            zone_dst = None
            attack_stage = "normal"
            is_anom = False
            is_slm = False

            if et == "auth":
                auth_result = "success" if random.random() < 0.985 else "fail"
                auth_type = random.choice(["Kerberos","NTLM","RDP","local"])
                logon_type = random.choice([2,3,10])
                protocol = random.choice(["RDP","SMB","WMI"])
                port = 3389 if protocol=="RDP" else (445 if protocol=="SMB" else None)
                bytes_in = random.randint(0,2000)
                bytes_out = random.randint(0,2000)
            elif et == "conn":
                dst_ip, dst_host = pick_dst_internal(h)
                protocol = random.choice(PROTOCOLS)
                port = random.choice([80,443,135,445,5985,22,3389,139])
                if random.random() < 0.02 and role in ("server","db"):
                    bytes_in = random.randint(5_000_000, 200_000_000)
                    bytes_out = random.randint(5_000_000, 200_000_000)
                else:
                    bytes_in = random.randint(100,5_000_000)
                    bytes_out = random.randint(100,5_000_000)
            elif et == "file":
                proc = random.choice(processes + ["robocopy.exe","powershell.exe","curl.exe","rclone.exe"])
                dst_ip, dst_host = pick_dst_internal(h)
                protocol = "SMB"
                port = 445
                if random.random() < 0.03 and role in ("server","db"):
                    bytes_in = random.randint(50_000_000,500_000_000)
                    bytes_out = random.randint(50_000_000,500_000_000)
                else:
                    bytes_in = random.randint(1_000,5_000_000)
                    bytes_out = random.randint(1_000,5_000_000)
            elif et == "exec":
                proc = random.choice(processes + ["wmic.exe","psexec.exe","powershell.exe","cmd.exe"])
                parent = random.choice(processes + ["services.exe","svchost.exe"])
                if random.random() < 0.25:
                    dst_ip, dst_host = pick_dst_internal(h)
                protocol = random.choice(["WMI","SMB",None])
                bytes_in = random.randint(0,20000)
                bytes_out = random.randint(0,20000)
            elif et == "registry":
                proc = random.choice(processes + ["reg.exe","powershell.exe"])
                bytes_in = random.randint(0,2000)
                bytes_out = random.randint(0,2000)
            elif et == "dns":
                proc = random.choice(processes + ["dnsmasq","systemd-resolved"])
                protocol = "DNS"
                port = 53
                dst_ip = f"8.8.{random.randint(0,255)}.{random.randint(0,255)}"
                bytes_in = random.randint(50,2000)
                bytes_out = random.randint(50,2000)
            elif et == "http":
                proc = random.choice(processes + ["chrome.exe","curl.exe","wget.exe","python.exe"])
                protocol = "HTTP"
                port = random.choice([80,443])
                dst_ip = f"13.{random.randint(0,255)}.{random.randint(0,255)}.{random.randint(1,254)}"
                bytes_in = random.randint(500,10_000_000)
                bytes_out = random.randint(500,10_000_000)

            if random.random() < 0.012 and not is_comp:
                is_anom = True
                if et == "file":
                    bytes_out = random.randint(50_000_000,200_000_000)
                if et == "auth":
                    auth_result = "fail"
                    auth_type = random.choice(["local","NTLM"])

            if is_comp and comp_win:
                recon_start = datetime.fromisoformat(comp_win["recon_window"][0].replace("Z","+00:00"))
                recon_end   = datetime.fromisoformat(comp_win["recon_window"][1].replace("Z","+00:00"))
                cred_start  = datetime.fromisoformat(comp_win["credential_access_window"][0].replace("Z","+00:00"))
                cred_end    = datetime.fromisoformat(comp_win["credential_access_window"][1].replace("Z","+00:00"))
                lat_start   = datetime.fromisoformat(comp_win["lateral_movement_window"][0].replace("Z","+00:00"))
                lat_end     = datetime.fromisoformat(comp_win["lateral_movement_window"][1].replace("Z","+00:00"))
                pers_start  = datetime.fromisoformat(comp_win["persistence_window"][0].replace("Z","+00:00"))
                pers_end    = datetime.fromisoformat(comp_win["persistence_window"][1].replace("Z","+00:00"))

                if recon_start <= ts <= recon_end:
                    if random.random() < 0.07:
                        et = random.choice(["conn","dns","exec"])
                        proc = random.choice(["nbtstat.exe","net.exe","netstat.exe","powershell.exe","arp.exe"])
                        attack_stage = "reconnaissance"
                        is_anom = random.random() < 0.45
                        if et == "conn":
                            dst_ip, dst_host = pick_dst_internal(h)
                            bytes_in = random.randint(20,5000)
                            bytes_out = random.randint(20,5000)
                elif cred_start <= ts <= cred_end:
                    if random.random() < 0.06:
                        et = random.choice(["exec","auth","registry"])
                        proc = random.choice(["mimikatz.exe","procdump.exe","rundll32.exe","lsass.exe","powershell.exe"])
                        attack_stage = "credential_access"
                        is_anom = True
                        if et == "auth":
                            dst_ip, dst_host = pick_dst_internal(h)
                            auth_result = "fail" if random.random() < 0.86 else "success"
                            auth_type = random.choice(["NTLM","Kerberos"])
                            logon_type = random.choice([3,10])
                        else:
                            parent = "svchost.exe"
                            if proc == "procdump.exe":
                                cmdline = "procdump -ma lsass.exe lsass.dmp"
                            elif proc == "mimikatz.exe":
                                cmdline = "mimikatz.exe \"privilege::debug\" \"sekurlsa::logonpasswords\""
                            else:
                                cmdline = ""
                elif lat_start <= ts <= lat_end:
                    if random.random() < 0.035:
                        et = random.choice(["exec","conn","file"])
                        attack_stage = "lateral_movement"
                        is_anom = True
                        is_slm = True
                        if et == "exec":
                            proc = random.choice(["psexec.exe","wmic.exe","powershell.exe","winrm.vbs"])
                            parent = "services.exe"
                            dst_ip, dst_host = pick_dst_internal(h)
                            cmdline = f"{proc} \\\\{dst_ip} -accepteula -c \"powershell -EncodedCommand {uuid.uuid4().hex[:24]}\""
                            protocol = random.choice(["SMB","WMI","RDP"])
                            port = 445 if protocol=="SMB" else (5985 if protocol=="WMI" else 3389)
                        elif et == "conn":
                            dst_ip, dst_host = pick_dst_internal(h)
                            protocol = random.choice(["SMB","RDP","WMI"])
                            port = random.choice([445,3389,5985])
                            bytes_in = random.randint(100,50000)
                            bytes_out = random.randint(100,50000)
                        elif et == "file":
                            proc = random.choice(["robocopy.exe","powershell.exe"])
                            dst_ip, dst_host = pick_dst_internal(h)
                            bytes_in = random.randint(5_000_000,50_000_000)
                            bytes_out = random.randint(5_000_000,50_000_000)
                elif pers_start <= ts <= pers_end:
                    if random.random() < 0.035:
                        et = random.choice(["registry","file","exec","conn"])
                        attack_stage = "persistence"
                        is_anom = True
                        is_slm = True
                        if et == "registry":
                            proc = "reg.exe"
                            parent = "powershell.exe"
                        elif et == "exec":
                            proc = "schtasks.exe"
                            parent = "services.exe"
                        elif et == "file":
                            staging_ip, staging_host = pick_dst_internal(h)
                            dst_ip, dst_host = staging_ip, staging_host
                            bytes_out = random.randint(10_000_000,200_000_000)
                            bytes_in = random.randint(0,5_000_000)
                        elif et == "conn":
                            dst_ip, dst_host = pick_dst_internal(h)
                            protocol = random.choice(["HTTP","SMB"])
                            port = 80 if protocol=="HTTP" else 445

            # set default cmdline if not injected by attack overlay
            if 'cmdline' in locals() and isinstance(locals().get('cmdline'), str) and locals().get('cmdline'):
                cmdline = locals().get('cmdline')
                try:
                    del locals()['cmdline']
                except Exception:
                    pass
            else:
                if et == "exec":
                    cmdline = random.choice([
                        f"{proc} /c do-something",
                        f"{proc} -NoProfile -ExecutionPolicy Bypass -Command Get-ChildItem",
                        f"{proc} -EncodedCommand {uuid.uuid4().hex[:16]}",
                        f"{proc} /s /q C:\\temp\\{random.randint(1,999)}"
                    ])
                elif et == "file":
                    if dst_ip:
                        cmdline = f"{proc} \\\\{dst_ip}\\share\\file{random.randint(1,500)}.dat"
                    else:
                        cmdline = f"{proc} C:\\temp\\file{random.randint(1,500)}.dat"
                elif et == "auth":
                    cmdline = ""
                elif et == "dns":
                    cmdline = f"query domain{random.randint(1,400)}.internal"
                elif et == "http":
                    cmdline = f"GET /api/{random.randint(1,2000)}"
                elif et == "registry":
                    cmdline = "reg query HKLM\\Software"
                else:
                    cmdline = ""

            privilege = "admin" if (random.random() < 0.03 or role == "admin" and random.random()<0.8) else "user"
            zd = "INTERNAL" if dst_ip and dst_ip.startswith("10.") else ("EXTERNAL" if dst_ip else None)

            rec = {
                "timestamp": iso(ts),
                "host_id": h,
                "host_hostname": host_profile[h]["hostname"],
                "host_id_unique": host_profile[h]["host_id_unique"],
                "user_id": user,
                "src_ip": src_ip,
                "dst_ip": dst_ip,
                "dst_host": dst_host,
                "event_type": et,
                "process_name": proc,
                "parent_process": parent,
                "command_line_args": cmdline,
                "session_id": session_id,
                "auth_result": auth_result,
                "auth_type": auth_type,
                "logon_type": logon_type,
                "protocol": protocol,
                "port": port,
                "bytes_in": bytes_in,
                "bytes_out": bytes_out,
                "privilege_level": privilege,
                "zone_src": "INTERNAL",
                "zone_dst": zd,
                "attack_stage": attack_stage,
                "is_anomaly_event": int(bool(is_anom)),
                "is_stealth_lateral_movement": int(bool(is_slm))
            }
            rows.append(rec)

    # enforce anomaly rate rule (< 5%); clamp down if necessary
    total = len(rows)
    anomalies = sum(r["is_anomaly_event"] for r in rows)
    slm_total = sum(r["is_stealth_lateral_movement"] for r in rows)
    global_anom_rate = anomalies / max(1,total)
    if global_anom_rate > 0.05:
        target_rate = 0.03
        to_flip = int((global_anom_rate - target_rate) * total)
        cand_idxs = [i for i,r in enumerate(rows) if r["is_anomaly_event"]==1 and r["is_stealth_lateral_movement"]==0]
        random.shuffle(cand_idxs)
        for i in cand_idxs[:to_flip]:
            rows[i]["is_anomaly_event"] = 0
        anomalies = sum(r["is_anomaly_event"] for r in rows)
        global_anom_rate = anomalies / max(1,total)

    header = [
        "timestamp","host_id","host_hostname","host_id_unique","user_id","src_ip","dst_ip","dst_host",
        "event_type","process_name","parent_process","command_line_args","session_id","auth_result","auth_type",
        "logon_type","protocol","port","bytes_in","bytes_out","privilege_level","zone_src","zone_dst","attack_stage",
        "is_anomaly_event","is_stealth_lateral_movement"
    ]
    with open(out_csv, "w", newline='', encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=header, quoting=csv.QUOTE_MINIMAL)
        writer.writeheader()
        for r in rows:
            writer.writerow({k: r.get(k, None) for k in header})

    meta_summary = {
        "seed": seed,
        "n_hosts": n_hosts,
        "days": days,
        "start_date": iso(start_date),
        "compromised_count": compromised_count,
        "compromised_hosts": meta["compromised_hosts"],
        "total_events": total,
        "total_anomalies": anomalies,
        "total_stealth_events": slm_total,
        "anomaly_rate": global_anom_rate
    }
    with open(out_meta, "w", encoding="utf-8") as f:
        json.dump(meta_summary, f, indent=2)

    print("Generation complete.")
    print(f" CSV: {out_csv}")
    print(f" META: {out_meta}")
    print(f" Total events: {total:,}")
    print(f" Total anomalies: {anomalies:,} (rate {global_anom_rate:.3%})")
    print(f" Total stealth-lateral events: {slm_total:,}")

# ----------------------------
# CLI: use parse_known_args() so Jupyter/kernel '--f=...' won't crash it
# ----------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate synthetic SLM event dataset")
    parser.add_argument("--seed", type=int, default=42, help="RNG seed (set to change dataset reproducibly)")
    parser.add_argument("--n-hosts", type=int, default=100)
    parser.add_argument("--days", type=int, default=15)
    parser.add_argument("--start-date", type=str, default="2025-10-01T00:00:00Z", help="ISO8601 UTC start date")
    parser.add_argument("--compromised", type=int, default=40, help="Number of compromised hosts")
    parser.add_argument("--avg-events", type=int, default=120, help="Average events per host per day")
    parser.add_argument("--out", type=str, default="slm_train_dataset.csv", help="Output CSV path")
    parser.add_argument("--meta", type=str, default="slm_metadata.json", help="Output metadata JSON path")

    # IMPORTANT: use parse_known_args so extra kernel args are ignored when running in Jupyter
    args, unknown = parser.parse_known_args()
    # parse start date into datetime
    try:
        sd = datetime.fromisoformat(args.start_date.replace("Z","+00:00")).astimezone(timezone.utc)
    except Exception:
        sd = datetime(2025,10,1,tzinfo=timezone.utc)

    generate_dataset(seed=args.seed,
                     n_hosts=args.n_hosts,
                     days=args.days,
                     start_date=sd,
                     compromised_count=args.compromised,
                     avg_events_per_host_day=args.avg_events,
                     out_csv=args.out,
                     out_meta=args.meta)

Generation complete.
 CSV: slm_train_dataset.csv
 META: slm_metadata.json
 Total events: 179,951
 Total anomalies: 4,205 (rate 2.337%)
 Total stealth-lateral events: 1,343
